Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Feb 1, 2023
1 parent c04c580 commit be5c5ad
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 63 deletions.
73 changes: 54 additions & 19 deletions distributed/shuffle/_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dask.layers import Layer
from dask.utils import stringify, stringify_collection_keys

from distributed.protocol.serialize import to_serialize
from distributed.shuffle._shuffle import (
ShuffleId,
_get_worker_extension,
Expand Down Expand Up @@ -65,13 +66,15 @@ def hash_join_p2p(
_lhs_meta = lhs._meta_nonempty if len(lhs.columns) else lhs._meta
_rhs_meta = rhs._meta_nonempty if len(rhs.columns) else rhs._meta
meta = _lhs_meta.merge(_rhs_meta, **merge_kwargs)
merge_name = "hash-join-" + tokenize(lhs, left_on, lhs, right_on, suffixes)
merge_name = "hash-join-" + tokenize(lhs, rhs, **merge_kwargs)
join_layer = HashJoinP2PLayer(
name=merge_name,
name_input_left=lhs._name,
left_on=left_on,
n_partitions_left=lhs.npartitions,
name_input_right=rhs._name,
right_on=right_on,
n_partitions_right=rhs.npartitions,
meta_output=meta,
how=how,
npartitions=npartitions,
Expand All @@ -86,6 +89,8 @@ def hash_join_p2p(

hash_join = hash_join_p2p

_HASH_COLUMN_NAME = "__partition"


def merge_transfer(
input: pd.DataFrame,
Expand All @@ -94,14 +99,13 @@ def merge_transfer(
npartitions: int,
column: str,
):
hash_column = "__partition"
input[hash_column] = partitioning_index(input[column], npartitions)
input[_HASH_COLUMN_NAME] = partitioning_index(input[column], npartitions)
return shuffle_transfer(
input=input,
id=id,
input_partition=input_partition,
npartitions=npartitions,
column=hash_column,
column=_HASH_COLUMN_NAME,
)


Expand All @@ -111,23 +115,34 @@ def merge_unpack(
output_partition: int,
barrier_left: int,
barrier_right: int,
how,
left_on,
right_on,
result_meta,
suffixes,
):
# FIXME: This is odd. There are similar things happening in dask/dask
# layers.py but this works for now
from dask.dataframe.multi import merge_chunk

from distributed.protocol import deserialize

# FIXME: This is odd.
result_meta = deserialize(result_meta.header, result_meta.frames)
ext = _get_worker_extension()
left = ext.get_output_partition(
shuffle_id_left, barrier_left, output_partition
).drop(columns="__partition")
).drop(columns=_HASH_COLUMN_NAME)
right = ext.get_output_partition(
shuffle_id_right, barrier_right, output_partition
).drop(columns="__partition")
return merge_chunk(left, right, result_meta=result_meta)
).drop(columns=_HASH_COLUMN_NAME)
return merge_chunk(
left,
right,
how=how,
result_meta=result_meta,
left_on=left_on,
right_on=right_on,
suffixes=suffixes,
)


class HashJoinP2PLayer(Layer):
Expand All @@ -136,6 +151,8 @@ def __init__(
name: str,
name_input_left: str,
left_on,
n_partitions_left: int,
n_partitions_right: int,
name_input_right: str,
right_on,
meta_output: pd.DataFrame,
Expand All @@ -157,6 +174,8 @@ def __init__(
self.indicator = indicator
self.meta_output = meta_output
self.parts_out = parts_out or list(range(npartitions))
self.n_partitions_left = n_partitions_left
self.n_partitions_right = n_partitions_right
annotations = annotations or {}
# TODO: This is more complex
annotations.update({"shuffle": lambda key: key[-1]})
Expand All @@ -169,7 +188,6 @@ def _cull_dependencies(self, keys, parts_out=None):
all input partitions. This method does not require graph
materialization.
"""
# FIXME: I believe this is just wrong. For P2PLayer as well
deps = defaultdict(set)
parts_out = parts_out or self._keys_to_parts(keys)
for part in parts_out:
Expand Down Expand Up @@ -236,6 +254,8 @@ def _cull(self, parts_out):
meta_output=self.meta_output,
parts_out=parts_out,
annotations=self.annotations,
n_partitions_left=self.n_partitions_left,
n_partitions_right=self.n_partitions_right,
)

def cull(self, keys, all_keys):
Expand All @@ -255,12 +275,18 @@ def cull(self, keys, all_keys):
return self, culled_deps

def _construct_graph(self) -> dict[tuple | str, tuple]:
token_left = tokenize(
self.name_input_left, self.left_on, self.npartitions, self.parts_out
)
token_right = tokenize(
self.name_input_right, self.right_on, self.npartitions, self.parts_out
args = (
self.left_on,
self.how,
self.npartitions,
self.n_partitions_left,
self.n_partitions_right,
self.parts_out,
self.suffixes,
self.indicator,
)
token_left = tokenize(self.name_input_left, *args)
token_right = tokenize(self.name_input_right, *args)
dsk: dict[tuple | str, tuple] = {}
# FIXME: This is a problem. The barrier key is parsed to infer the
# shuffle ID that is currently used in the transition hook.
Expand All @@ -270,7 +296,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
name_right = "hash-join-transfer-" + token_right
transfer_keys_left = list()
transfer_keys_right = list()
for i in range(self.npartitions):
for i in range(self.n_partitions_left):
transfer_keys_left.append((name_left, i))
dsk[(name_left, i)] = (
merge_transfer,
Expand All @@ -280,6 +306,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
self.npartitions,
self.left_on,
)
for i in range(self.n_partitions_right):
transfer_keys_right.append((name_right, i))
dsk[(name_right, i)] = (
merge_transfer,
Expand All @@ -304,7 +331,11 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
part_out,
_barrier_key_left,
_barrier_key_right,
self.how,
self.left_on,
self.right_on,
self.meta_output,
self.suffixes,
)
return dsk

Expand All @@ -314,6 +345,7 @@ def __dask_distributed_unpack__(cls, state, dsk, dependecies):

to_list = [
"left_on",
"suffixes",
"right_on",
]
# msgpack will convert lists into tuples, here
Expand All @@ -331,14 +363,15 @@ def __dask_distributed_unpack__(cls, state, dsk, dependecies):
}
keys = layer_dsk.keys() | dsk.keys()
# TODO: use shuffle-knowledge to calculate dependencies more efficiently
deps = {k: keys_in_tasks(keys, [v]) for k, v in layer_dsk.items()}

deps = {}
for k, v in layer_dsk.items():
deps[k] = d = keys_in_tasks(keys, [v])
assert d
return {"dsk": toolz.valmap(dumps_task, layer_dsk), "deps": deps}

def __dask_distributed_pack__(
self, all_hlg_keys, known_key_dependencies, client, client_keys
):
from distributed.protocol.serialize import to_serialize

return {
"name": self.name,
Expand All @@ -352,4 +385,6 @@ def __dask_distributed_pack__(
"suffixes": self.suffixes,
"indicator": self.indicator,
"parts_out": self.parts_out,
"n_partitions_left": self.n_partitions_left,
"n_partitions_right": self.n_partitions_right,
}
182 changes: 182 additions & 0 deletions distributed/shuffle/tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from __future__ import annotations

import pytest

from distributed.shuffle._merge import hash_join
from distributed.utils_test import gen_cluster

dd = pytest.importorskip("dask.dataframe")
import pandas as pd

from dask.dataframe._compat import tm
from dask.dataframe.utils import assert_eq
from dask.utils_test import hlg_layer_topological


def list_eq(aa, bb):
if isinstance(aa, dd.DataFrame):
a = aa.compute(scheduler="sync")
else:
a = aa
if isinstance(bb, dd.DataFrame):
b = bb.compute(scheduler="sync")
else:
b = bb
tm.assert_index_equal(a.columns, b.columns)

if isinstance(a, pd.DataFrame):
av = a.sort_values(list(a.columns)).values
bv = b.sort_values(list(b.columns)).values
else:
av = a.sort_values().values
bv = b.sort_values().values

dd._compat.assert_numpy_array_equal(av, bv)


@pytest.mark.parametrize("how", ["inner", "left", "right", "outer"])
@gen_cluster(client=True)
async def test_basic_merge(c, s, a, b, how):

A = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6], "y": [1, 1, 2, 2, 3, 4]})
a = dd.repartition(A, [0, 4, 5])

B = pd.DataFrame({"y": [1, 3, 4, 4, 5, 6], "z": [6, 5, 4, 3, 2, 1]})
b = dd.repartition(B, [0, 2, 5])

joined = hash_join(a, "y", b, "y", how)

assert not hlg_layer_topological(joined.dask, -1).is_materialized()
result = await c.compute(joined)
expected = pd.merge(A, B, how, "y")
list_eq(result, expected)

# Different columns and npartitions
joined = hash_join(a, "x", b, "z", "outer", npartitions=3)
assert not hlg_layer_topological(joined.dask, -1).is_materialized()
assert joined.npartitions == 3

result = await c.compute(joined)
expected = pd.merge(A, B, "outer", None, "x", "z")

list_eq(result, expected)

assert (
hash_join(a, "y", b, "y", "inner")._name
== hash_join(a, "y", b, "y", "inner")._name
)
assert (
hash_join(a, "y", b, "y", "inner")._name
!= hash_join(a, "y", b, "y", "outer")._name
)


@pytest.mark.parametrize("how", ["inner", "outer", "left", "right"])
@gen_cluster(client=True)
async def test_merge(c, s, a, b, how):
shuffle_method = "p2p"
A = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6], "y": [1, 1, 2, 2, 3, 4]})
a = dd.repartition(A, [0, 4, 5])

B = pd.DataFrame({"y": [1, 3, 4, 4, 5, 6], "z": [6, 5, 4, 3, 2, 1]})
b = dd.repartition(B, [0, 2, 5])

res = await c.compute(
dd.merge(
a, b, left_index=True, right_index=True, how=how, shuffle=shuffle_method
)
)
assert_eq(
res,
pd.merge(A, B, left_index=True, right_index=True, how=how),
)
joined = dd.merge(a, b, on="y", how=how)
result = await c.compute(joined)
list_eq(result, pd.merge(A, B, on="y", how=how))
assert all(d is None for d in joined.divisions)

list_eq(
await c.compute(
dd.merge(a, b, left_on="x", right_on="z", how=how, shuffle=shuffle_method)
),
pd.merge(A, B, left_on="x", right_on="z", how=how),
)
list_eq(
await c.compute(
dd.merge(
a,
b,
left_on="x",
right_on="z",
how=how,
suffixes=("1", "2"),
shuffle=shuffle_method,
)
),
pd.merge(A, B, left_on="x", right_on="z", how=how, suffixes=("1", "2")),
)

list_eq(
await c.compute(dd.merge(a, b, how=how, shuffle=shuffle_method)),
pd.merge(A, B, how=how),
)
list_eq(
await c.compute(dd.merge(a, B, how=how, shuffle=shuffle_method)),
pd.merge(A, B, how=how),
)
list_eq(
await c.compute(dd.merge(A, b, how=how, shuffle=shuffle_method)),
pd.merge(A, B, how=how),
)
# Note: No await since A and B are both pandas dataframes and this doesn't
# actually submit anything
list_eq(
c.compute(dd.merge(A, B, how=how, shuffle=shuffle_method)),
pd.merge(A, B, how=how),
)

list_eq(
await c.compute(
dd.merge(
a, b, left_index=True, right_index=True, how=how, shuffle=shuffle_method
)
),
pd.merge(A, B, left_index=True, right_index=True, how=how),
)
list_eq(
await c.compute(
dd.merge(
a,
b,
left_index=True,
right_index=True,
how=how,
suffixes=("1", "2"),
shuffle=shuffle_method,
)
),
pd.merge(A, B, left_index=True, right_index=True, how=how, suffixes=("1", "2")),
)

list_eq(
await c.compute(
dd.merge(
a, b, left_on="x", right_index=True, how=how, shuffle=shuffle_method
)
),
pd.merge(A, B, left_on="x", right_index=True, how=how),
)
list_eq(
await c.compute(
dd.merge(
a,
b,
left_on="x",
right_index=True,
how=how,
suffixes=("1", "2"),
shuffle=shuffle_method,
)
),
pd.merge(A, B, left_on="x", right_index=True, how=how, suffixes=("1", "2")),
)
Loading

0 comments on commit be5c5ad

Please sign in to comment.