Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove some dask dependencies #156

Merged
merged 5 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import tlz
from dask.array.core import broadcast_chunks, normalize_chunks
from dask.array.reshape import reshape_rechunk
from dask.array.slicing import sanitize_index
from dask.array.utils import validate_axis
from toolz import reduce

Expand Down Expand Up @@ -185,15 +184,14 @@ def permute_dims(x, /, axes):
def reshape(x, /, shape):
# based on dask reshape

shape = tuple(map(sanitize_index, shape))
known_sizes = [s for s in shape if s != -1]
if len(known_sizes) != len(shape):
if len(shape) - len(known_sizes) > 1:
raise ValueError("can only specify one unknown dimension")
# Fastpath for x.reshape(-1) on 1D arrays
if len(shape) == 1 and x.ndim == 1:
return x
missing_size = sanitize_index(x.size / reduce(mul, known_sizes, 1))
missing_size = x.size // reduce(mul, known_sizes, 1)
shape = tuple(missing_size if s == -1 else s for s in shape)

if reduce(mul, shape, 1) != x.size:
Expand Down
36 changes: 34 additions & 2 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math

import numpy as np
from dask.array.reductions import numel

from cubed.array_api.dtypes import (
_numeric_dtypes,
Expand Down Expand Up @@ -37,7 +38,7 @@ def mean(x, /, *, axis=None, keepdims=False):


def _mean_func(a, **kwargs):
    """Per-chunk piece of a mean reduction: element count and running sum."""
    return {"n": _numel(a, **kwargs), "total": np.sum(a, **kwargs)}

Expand All @@ -52,6 +53,37 @@ def _mean_aggregate(a):
return np.divide(a["total"], a["n"])


# based on dask
def _numel(x, **kwargs):
"""
A reduction to count the number of elements.
"""
shape = x.shape
keepdims = kwargs.get("keepdims", False)
axis = kwargs.get("axis", None)
dtype = kwargs.get("dtype", np.float64)

if axis is None:
prod = np.prod(shape, dtype=dtype)
if keepdims is False:
return prod

return np.full(shape=(1,) * len(shape), fill_value=prod, dtype=dtype)

if not isinstance(axis, (tuple, list)):
axis = [axis]

prod = math.prod(shape[dim] for dim in axis)
if keepdims is True:
new_shape = tuple(
shape[dim] if dim not in axis else 1 for dim in range(len(shape))
)
else:
new_shape = tuple(shape[dim] for dim in range(len(shape)) if dim not in axis)

return np.broadcast_to(np.array(prod, dtype=dtype), new_shape)


def min(x, /, *, axis=None, keepdims=False):
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in min")
Expand Down
10 changes: 5 additions & 5 deletions cubed/core/gufunc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from dask.array.gufunc import _parse_gufunc_signature, _validate_normalize_axes
from dask.array.gufunc import _parse_gufunc_signature
from tlz import concat, merge, unique


Expand Down Expand Up @@ -30,7 +30,7 @@ def apply_gufunc(
"""

# Currently the following parameters cannot be changed
keepdims = False
# keepdims = False
allow_rechunk = False

# based on dask's apply_gufunc
Expand Down Expand Up @@ -60,9 +60,9 @@ def apply_gufunc(
output_sizes = {}

# Axes
input_axes, output_axes = _validate_normalize_axes(
axes, axis, keepdims, input_coredimss, output_coredimss
)
# input_axes, output_axes = _validate_normalize_axes(
# axes, axis, keepdims, input_coredimss, output_coredimss
# )

# Main code:

Expand Down
10 changes: 3 additions & 7 deletions cubed/runtime/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import math
from typing import Any, Iterable, Iterator, List, Tuple

import dask
import numpy as np

from cubed.primitive.types import CubedPipeline
def copy_read_to_write(chunk_key, *, config=CopySpec):
    """Copy one chunk from the read array into the write array."""
    # workaround limitation of lithops.utils.verify_args
    if isinstance(chunk_key, list):
        chunk_key = tuple(chunk_key)
    config.write.array[chunk_key] = np.asarray(config.read.array[chunk_key])


def copy_read_to_intermediate(chunk_key, *, config=CopySpec):
    """Copy one chunk from the read array into the intermediate array."""
    # workaround limitation of lithops.utils.verify_args
    if isinstance(chunk_key, list):
        chunk_key = tuple(chunk_key)
    config.intermediate.array[chunk_key] = np.asarray(config.read.array[chunk_key])


def copy_intermediate_to_write(chunk_key, *, config=CopySpec):
    """Copy one chunk from the intermediate array into the write array."""
    # workaround limitation of lithops.utils.verify_args
    if isinstance(chunk_key, list):
        chunk_key = tuple(chunk_key)
    config.write.array[chunk_key] = np.asarray(config.intermediate.array[chunk_key])


Expand Down
13 changes: 11 additions & 2 deletions cubed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
import traceback
from dataclasses import dataclass
from math import prod
from operator import add
from pathlib import Path
from posixpath import join
from resource import RUSAGE_SELF, getrusage
from typing import Union
from urllib.parse import quote, unquote, urlsplit, urlunsplit

import numpy as np
import tlz as toolz
from dask.array.core import _check_regular_chunks
from dask.utils import cached_cumsum

PathType = Union[str, Path]

Expand All @@ -26,11 +27,19 @@ def chunk_memory(dtype, chunksize):

def get_item(chunks, idx):
    """Convert a chunk index to a tuple of slices."""
    # could use Dask's cached_cumsum here if it improves performance
    offsets = [_cumsum(c, initial_zero=True) for c in chunks]
    return tuple(
        slice(offs[i], offs[i + 1], None) for i, offs in zip(idx, offsets)
    )


def _cumsum(seq, initial_zero=False):
if initial_zero:
return tuple(toolz.accumulate(add, seq, 0))
else:
return tuple(toolz.accumulate(add, seq))


def join_path(dir_url: PathType, child_path: str) -> str:
"""Combine a URL for a directory with a child path"""
parts = urlsplit(str(dir_url))
Expand Down