[BACKPORT] Do not aggressively choose tree method in tile of groupby for distributed setting (#3032) (#3070)
Xuye (Chris) Qin authored May 23, 2022
1 parent f93ef28 commit 7d78538
Showing 4 changed files with 105 additions and 57 deletions.
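For orientation, the two tiling strategies this commit arbitrates between can be sketched in plain pandas. This is a conceptual sketch only, not Mars internals; the chunk contents and the fixed two-way split are illustrative:

import pandas as pd

# two input chunks of one logical DataFrame
chunks = [
    pd.DataFrame({"k": ["a", "b"], "v": [1, 2]}),
    pd.DataFrame({"k": ["a", "c"], "v": [3, 4]}),
]

# Tree method: pre-aggregate every chunk, then combine the partial results.
# Cheap when aggregation shrinks the data, but the combine step concentrates
# work on few workers.
partials = [c.groupby("k", as_index=False)["v"].sum() for c in chunks]
tree = pd.concat(partials).groupby("k", as_index=False)["v"].sum()

# Shuffle method: repartition rows by key so each partition holds disjoint
# keys, then aggregate each partition independently (a real shuffle hashes
# the key; a fixed split keeps this sketch deterministic).
rows = pd.concat(chunks, ignore_index=True)
parts = [rows[rows["k"] <= "b"], rows[rows["k"] > "b"]]
shuffle = pd.concat(
    [p.groupby("k", as_index=False)["v"].sum() for p in parts],
    ignore_index=True,
)

assert tree.sort_values("k").reset_index(drop=True).equals(
    shuffle.sort_values("k").reset_index(drop=True)
)

The commit's point is that the tree method's combine step is a poor fit for a multi-worker cluster when the partial results stay large, which is what the new heuristic in mars/dataframe/groupby/aggregation.py detects.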
83 changes: 47 additions & 36 deletions benchmarks/tpch/run_queries.py
@@ -17,12 +17,12 @@
import argparse
import functools
import time
from typing import Callable
from typing import Callable, List, Optional, Set, Union

import mars
import mars.dataframe as md

queries = None
queries: Optional[Union[Set[str], List[str]]] = None


def load_lineitem(data_folder: str) -> md.DataFrame:
@@ -158,7 +158,8 @@ def q01(lineitem: md.DataFrame):
"L_ORDERKEY": "count",
}
)
total = total.sort_values(["L_RETURNFLAG", "L_LINESTATUS"])
# skip explicit sort: Mars groupby sorts group keys by default
# total = total.sort_values(["L_RETURNFLAG", "L_LINESTATUS"])
print(total.execute())


@@ -238,7 +239,9 @@ def q02(part, partsupp, supplier, nation, region):
"P_MFGR",
],
]
min_values = merged_df.groupby("P_PARTKEY", as_index=False)["PS_SUPPLYCOST"].min()
min_values = merged_df.groupby("P_PARTKEY", as_index=False, sort=False)[
"PS_SUPPLYCOST"
].min()
min_values.columns = ["P_PARTKEY_CPY", "MIN_SUPPLYCOST"]
merged_df = merged_df.merge(
min_values,
@@ -286,9 +289,9 @@ def q03(lineitem, orders, customer):
jn2 = jn1.merge(flineitem, left_on="O_ORDERKEY", right_on="L_ORDERKEY")
jn2["TMP"] = jn2.L_EXTENDEDPRICE * (1 - jn2.L_DISCOUNT)
total = (
jn2.groupby(["L_ORDERKEY", "O_ORDERDATE", "O_SHIPPRIORITY"], as_index=False)[
"TMP"
]
jn2.groupby(
["L_ORDERKEY", "O_ORDERDATE", "O_SHIPPRIORITY"], as_index=False, sort=False
)["TMP"]
.sum()
.sort_values(["TMP"], ascending=False)
)
@@ -307,9 +310,9 @@ def q04(lineitem, orders):
forders = orders[osel]
jn = forders[forders["O_ORDERKEY"].isin(flineitem["L_ORDERKEY"])]
total = (
jn.groupby("O_ORDERPRIORITY", as_index=False)["O_ORDERKEY"]
.count()
.sort_values(["O_ORDERPRIORITY"])
jn.groupby("O_ORDERPRIORITY", as_index=False)["O_ORDERKEY"].count()
# skip explicit sort: Mars groupby sorts group keys by default
# .sort_values(["O_ORDERPRIORITY"])
)
print(total.execute())

@@ -330,7 +333,7 @@ def q05(lineitem, orders, customer, nation, region, supplier):
jn4, left_on=["S_SUPPKEY", "S_NATIONKEY"], right_on=["L_SUPPKEY", "N_NATIONKEY"]
)
jn5["TMP"] = jn5.L_EXTENDEDPRICE * (1.0 - jn5.L_DISCOUNT)
gb = jn5.groupby("N_NAME", as_index=False)["TMP"].sum()
gb = jn5.groupby("N_NAME", as_index=False, sort=False)["TMP"].sum()
total = gb.sort_values("TMP", ascending=False)
print(total.execute())

@@ -436,9 +439,10 @@ def q07(lineitem, supplier, orders, customer, nation):
total = total.groupby(["SUPP_NATION", "CUST_NATION", "L_YEAR"], as_index=False).agg(
REVENUE=md.NamedAgg(column="VOLUME", aggfunc="sum")
)
total = total.sort_values(
by=["SUPP_NATION", "CUST_NATION", "L_YEAR"], ascending=[True, True, True]
)
# skip explicit sort: Mars groupby sorts group keys by default
# total = total.sort_values(
# by=["SUPP_NATION", "CUST_NATION", "L_YEAR"], ascending=[True, True, True]
# )
print(total.execute())


@@ -520,7 +524,7 @@ def q09(lineitem, orders, part, nation, partsupp, supplier):
(1 * jn5.PS_SUPPLYCOST) * jn5.L_QUANTITY
)
jn5["O_YEAR"] = jn5.O_ORDERDATE.dt.year
gb = jn5.groupby(["N_NAME", "O_YEAR"], as_index=False)["TMP"].sum()
gb = jn5.groupby(["N_NAME", "O_YEAR"], as_index=False, sort=False)["TMP"].sum()
total = gb.sort_values(["N_NAME", "O_YEAR"], ascending=[True, False])
print(total.execute())

@@ -548,6 +552,7 @@ def q10(lineitem, orders, customer, nation):
"C_COMMENT",
],
as_index=False,
sort=False,
)["TMP"].sum()
total = gb.sort_values("TMP", ascending=False)
print(total.head(20).execute())
@@ -571,7 +576,7 @@ def q11(partsupp, supplier, nation):
)
ps_supp_n_merge = ps_supp_n_merge.loc[:, ["PS_PARTKEY", "TOTAL_COST"]]
sum_val = ps_supp_n_merge["TOTAL_COST"].sum() * 0.0001
total = ps_supp_n_merge.groupby(["PS_PARTKEY"], as_index=False).agg(
total = ps_supp_n_merge.groupby(["PS_PARTKEY"], as_index=False, sort=False).agg(
VALUE=md.NamedAgg(column="TOTAL_COST", aggfunc="sum")
)
total = total[total["VALUE"] > sum_val]
@@ -603,7 +608,8 @@ def g2(x):

total = jn.groupby("L_SHIPMODE", as_index=False)["O_ORDERPRIORITY"].agg((g1, g2))
total = total.reset_index() # reset index to keep consistency with pandas
total = total.sort_values("L_SHIPMODE")
# skip explicit sort: groupby sorts group keys by default
# total = total.sort_values("L_SHIPMODE")
print(total.execute())


@@ -618,10 +624,10 @@ def q13(customer, orders):
orders_filtered, left_on="C_CUSTKEY", right_on="O_CUSTKEY", how="left"
)
c_o_merged = c_o_merged.loc[:, ["C_CUSTKEY", "O_ORDERKEY"]]
count_df = c_o_merged.groupby(["C_CUSTKEY"], as_index=False).agg(
count_df = c_o_merged.groupby(["C_CUSTKEY"], as_index=False, sort=False).agg(
C_COUNT=md.NamedAgg(column="O_ORDERKEY", aggfunc="count")
)
total = count_df.groupby(["C_COUNT"], as_index=False).size()
total = count_df.groupby(["C_COUNT"], as_index=False, sort=False).size()
total.columns = ["C_COUNT", "CUSTDIST"]
total = total.sort_values(by=["CUSTDIST", "C_COUNT"], ascending=[False, False])
print(total.execute())
@@ -660,7 +666,7 @@ def q15(lineitem, supplier):
)
lineitem_filtered = lineitem_filtered.loc[:, ["L_SUPPKEY", "REVENUE_PARTS"]]
revenue_table = (
lineitem_filtered.groupby("L_SUPPKEY", as_index=False)
lineitem_filtered.groupby("L_SUPPKEY", as_index=False, sort=False)
.agg(TOTAL_REVENUE=md.NamedAgg(column="REVENUE_PARTS", aggfunc="sum"))
.rename(columns={"L_SUPPKEY": "SUPPLIER_NO"})
)
@@ -699,7 +705,7 @@ def q16(part, partsupp, supplier):
)
total = total[total["S_SUPPKEY"].isna()]
total = total.loc[:, ["P_BRAND", "P_TYPE", "P_SIZE", "PS_SUPPKEY"]]
total = total.groupby(["P_BRAND", "P_TYPE", "P_SIZE"], as_index=False)[
total = total.groupby(["P_BRAND", "P_TYPE", "P_SIZE"], as_index=False, sort=False)[
"PS_SUPPKEY"
].nunique()
total.columns = ["P_BRAND", "P_TYPE", "P_SIZE", "SUPPLIER_CNT"]
@@ -722,9 +728,9 @@ def q17(lineitem, part):
:, ["L_QUANTITY", "L_EXTENDEDPRICE", "P_PARTKEY"]
]
lineitem_filtered = lineitem.loc[:, ["L_PARTKEY", "L_QUANTITY"]]
lineitem_avg = lineitem_filtered.groupby(["L_PARTKEY"], as_index=False).agg(
avg=md.NamedAgg(column="L_QUANTITY", aggfunc="mean")
)
lineitem_avg = lineitem_filtered.groupby(
["L_PARTKEY"], as_index=False, sort=False
).agg(avg=md.NamedAgg(column="L_QUANTITY", aggfunc="mean"))
lineitem_avg["avg"] = 0.2 * lineitem_avg["avg"]
lineitem_avg = lineitem_avg.loc[:, ["L_PARTKEY", "avg"]]
total = line_part_merge.merge(
@@ -737,13 +743,14 @@ def q17(lineitem, part):

@tpc_query
def q18(lineitem, orders, customer):
gb1 = lineitem.groupby("L_ORDERKEY", as_index=False)["L_QUANTITY"].sum()
gb1 = lineitem.groupby("L_ORDERKEY", as_index=False, sort=False)["L_QUANTITY"].sum()
fgb1 = gb1[gb1.L_QUANTITY > 300]
jn1 = fgb1.merge(orders, left_on="L_ORDERKEY", right_on="O_ORDERKEY")
jn2 = jn1.merge(customer, left_on="O_CUSTKEY", right_on="C_CUSTKEY")
gb2 = jn2.groupby(
["C_NAME", "C_CUSTKEY", "O_ORDERKEY", "O_ORDERDATE", "O_TOTALPRICE"],
as_index=False,
sort=False,
)["L_QUANTITY"].sum()
total = gb2.sort_values(["O_TOTALPRICE", "O_ORDERDATE"], ascending=[False, True])
print(total.head(100).execute())
@@ -865,9 +872,9 @@ def q20(lineitem, part, nation, partsupp, supplier):
left_on=["PS_PARTKEY", "PS_SUPPKEY"],
right_on=["L_PARTKEY", "L_SUPPKEY"],
)
gb = jn2.groupby(["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY"], as_index=False)[
"L_QUANTITY"
].sum()
gb = jn2.groupby(
["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY"], as_index=False, sort=False
)["L_QUANTITY"].sum()
gbsel = gb.PS_AVAILQTY > (0.5 * gb.L_QUANTITY)
fgb = gb[gbsel]
jn3 = fgb.merge(supplier, left_on="PS_SUPPKEY", right_on="S_SUPPKEY")
@@ -886,7 +893,7 @@ def q21(lineitem, orders, supplier, nation):
# Keep all rows that have another row in lineitem with the same orderkey and different suppkey
lineitem_orderkeys = (
lineitem_filtered.loc[:, ["L_ORDERKEY", "L_SUPPKEY"]]
.groupby("L_ORDERKEY", as_index=False)["L_SUPPKEY"]
.groupby("L_ORDERKEY", as_index=False, sort=False)["L_SUPPKEY"]
.nunique()
)
lineitem_orderkeys.columns = ["L_ORDERKEY", "nunique_col"]
@@ -905,9 +912,9 @@ def q21(lineitem, orders, supplier, nation):
)

# Not Exists: Check the exists condition isn't still satisfied on the output.
lineitem_orderkeys = lineitem_filtered.groupby("L_ORDERKEY", as_index=False)[
"L_SUPPKEY"
].nunique()
lineitem_orderkeys = lineitem_filtered.groupby(
"L_ORDERKEY", as_index=False, sort=False
)["L_SUPPKEY"].nunique()
lineitem_orderkeys.columns = ["L_ORDERKEY", "nunique_col"]
lineitem_orderkeys = lineitem_orderkeys[lineitem_orderkeys["nunique_col"] == 1]
lineitem_orderkeys = lineitem_orderkeys.loc[:, ["L_ORDERKEY"]]
@@ -936,7 +943,7 @@ def q21(lineitem, orders, supplier, nation):
nation_filtered, left_on="S_NATIONKEY", right_on="N_NATIONKEY", how="inner"
)
total = total.loc[:, ["S_NAME"]]
total = total.groupby("S_NAME", as_index=False).size()
total = total.groupby("S_NAME", as_index=False, sort=False).size()
total.columns = ["S_NAME", "NUMWAIT"]
total = total.sort_values(by=["NUMWAIT", "S_NAME"], ascending=[False, True])
print(total.execute())
@@ -966,17 +973,21 @@ def q22(customer, orders):
customer_filtered, on="C_CUSTKEY", how="inner"
)
customer_selected = customer_selected.loc[:, ["CNTRYCODE", "C_ACCTBAL"]]
agg1 = customer_selected.groupby(["CNTRYCODE"], as_index=False).size()
agg1 = customer_selected.groupby(["CNTRYCODE"], as_index=False, sort=False).size()
agg1.columns = ["CNTRYCODE", "NUMCUST"]
agg2 = customer_selected.groupby(["CNTRYCODE"], as_index=False).agg(
agg2 = customer_selected.groupby(["CNTRYCODE"], as_index=False, sort=False).agg(
TOTACCTBAL=md.NamedAgg(column="C_ACCTBAL", aggfunc="sum")
)
total = agg1.merge(agg2, on="CNTRYCODE", how="inner")
total = total.sort_values(by=["CNTRYCODE"], ascending=[True])
print(total.execute())


def run_queries(data_folder: str):
def run_queries(data_folder: str, select: Optional[List[str]] = None):
if select:
global queries
queries = select

# Load the data
t1 = time.time()
lineitem = load_lineitem(data_folder)
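The benchmark edits above all follow one pattern: pass sort=False when the groupby output is re-sorted by some other column anyway, and drop an explicit sort_values on the group keys where groupby's default sorting already orders the result. A minimal pandas illustration (Mars mirrors this API; the column names are made up):

import pandas as pd

df = pd.DataFrame({"key": ["b", "a", "b", "a"], "val": [1, 2, 3, 4]})

# groupby sorts group keys by default, so a following sort_values on the
# key is redundant:
by_key = df.groupby("key", as_index=False)["val"].sum()
assert by_key["key"].is_monotonic_increasing

# when the result is re-sorted by another column anyway, the groupby-time
# sort is wasted work and can be disabled:
by_val = (
    df.groupby("key", as_index=False, sort=False)["val"]
    .sum()
    .sort_values("val", ascending=False)
)
print(by_val)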
56 changes: 43 additions & 13 deletions mars/dataframe/groupby/aggregation.py
@@ -24,9 +24,9 @@

from ... import opcodes as OperandDef
from ...config import options
from ...core.custom_log import redirect_custom_log
from ...core import ENTITY_TYPE, OutputType
from ...core.context import get_context
from ...core.custom_log import redirect_custom_log
from ...core.context import get_context, Context
from ...core.operand import OperandStage
from ...serialization.serializables import (
Int32Field,
@@ -65,7 +65,8 @@
cudf = lazy_import("cudf")

logger = logging.getLogger(__name__)

CV_THRESHOLD = 0.2
MEAN_RATIO_THRESHOLD = 2 / 3
_support_get_group_without_as_index = pd_release_version[:2] > (1, 0)


@@ -783,11 +784,36 @@ def _combine_tree(

@classmethod
def _choose_tree_method(
cls, raw_sizes, agg_sizes, sample_count, total_count, chunk_store_limit
):
cls,
raw_sizes: List[int],
agg_sizes: List[int],
sample_count: int,
total_count: int,
chunk_store_limit: int,
ctx: Context,
) -> bool:
logger.debug(
"Start to choose method for Groupby, agg_sizes: %s, raw_sizes: %s, "
"sample_count: %s, total_count: %s, chunk_store_limit: %s",
agg_sizes,
raw_sizes,
sample_count,
total_count,
chunk_store_limit,
)
estimate_size = sum(agg_sizes) / sample_count * total_count
if (
len(ctx.get_worker_addresses()) > 1
and estimate_size > chunk_store_limit
and np.mean(agg_sizes) > 1024**2
):
# for distributed settings, if the estimated size is potentially large
# and each chunk is large enough (>1M; small chunks mean large sampling error),
# we choose shuffle
return False
# calculate the coefficient of variation of aggregation sizes,
# if the CV is less than 0.2 and the mean of agg_size/raw_size
# is less than 0.8, we suppose the single chunk's aggregation size
# if the CV is less than CV_THRESHOLD and the mean of agg_size/raw_size
# is less than MEAN_RATIO_THRESHOLD, we suppose the single chunk's aggregation size
# almost equals the tileable's, then use the tree method,
# as combining aggregation results won't lead to rapid expansion.
ratios = [
@@ -796,12 +822,11 @@ def _choose_tree_method(
cv = variation(agg_sizes)
mean_ratio = np.mean(ratios)
if mean_ratio <= 1 / sample_count:
# if mean of ratio is less than 0.25, use tree
return True
if cv <= 0.2 and mean_ratio <= 2 / 3:
if cv <= CV_THRESHOLD and mean_ratio <= MEAN_RATIO_THRESHOLD:
# check CV and mean of ratio
return True
elif sum(agg_sizes) / sample_count * total_count <= chunk_store_limit:
if estimate_size <= chunk_store_limit:
# if estimated size less than `chunk_store_limit`, use tree.
return True
return False
@@ -835,9 +860,14 @@ def _tile_auto(
left_chunks = in_df.chunks[combine_size:]
left_chunks = cls._gen_map_chunks(op, left_chunks, out_df, func_infos)
if cls._choose_tree_method(
raw_sizes, agg_sizes, len(chunks), len(in_df.chunks), op.chunk_store_limit
raw_sizes,
agg_sizes,
len(chunks),
len(in_df.chunks),
op.chunk_store_limit,
ctx,
):
logger.debug("Choose tree method for groupby operand %s", op)
logger.info("Choose tree method for groupby operand %s", op)
return cls._combine_tree(op, chunks + left_chunks, out_df, func_infos)
else:
# otherwise, use shuffle
@@ -847,7 +877,7 @@ def _tile_auto(
sample_chunks = cls._sample_chunks(op, chunks + left_chunks)
pivot_chunk = cls._gen_pivot_chunk(op, sample_chunks, agg_chunk_len)

logger.debug("Choose shuffle method for groupby operand %s", op)
logger.info("Choose shuffle method for groupby operand %s", op)
return cls._perform_shuffle(
op, chunks + left_chunks, in_df, out_df, func_infos, pivot_chunk
)
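Pulled out of the operand for readability, the new decision logic amounts to the following standalone sketch. The function name and the worker_count parameter (standing in for len(ctx.get_worker_addresses())) are ours, and numpy replaces scipy.stats.variation for the coefficient of variation; the thresholds mirror the diff:

from typing import List

import numpy as np

CV_THRESHOLD = 0.2
MEAN_RATIO_THRESHOLD = 2 / 3


def choose_tree_method(
    raw_sizes: List[int],
    agg_sizes: List[int],
    sample_count: int,
    total_count: int,
    chunk_store_limit: int,
    worker_count: int,
) -> bool:
    # extrapolate the sampled aggregation sizes to the whole tileable
    estimate_size = sum(agg_sizes) / sample_count * total_count
    if (
        worker_count > 1
        and estimate_size > chunk_store_limit
        and np.mean(agg_sizes) > 1024**2
    ):
        # distributed, potentially large result, and chunks big enough to
        # trust the sample: prefer shuffle
        return False
    ratios = [agg / raw for agg, raw in zip(agg_sizes, raw_sizes)]
    cv = np.std(agg_sizes) / np.mean(agg_sizes)  # coefficient of variation
    mean_ratio = np.mean(ratios)
    if mean_ratio <= 1 / sample_count:
        # aggregation shrinks the data drastically: tree is safe
        return True
    if cv <= CV_THRESHOLD and mean_ratio <= MEAN_RATIO_THRESHOLD:
        # uniform, clearly shrinking chunks: combining stays cheap
        return True
    # fall back on the absolute size estimate
    return estimate_size <= chunk_store_limit


# two ~4 MiB chunks sampled out of ten, each aggregating to ~3 MiB, on a
# 4-worker cluster with a 1 MiB chunk store limit: shuffle wins
print(
    choose_tree_method(
        raw_sizes=[4 * 1024**2] * 2,
        agg_sizes=[3 * 1024**2] * 2,
        sample_count=2,
        total_count=10,
        chunk_store_limit=1024**2,
        worker_count=4,
    )
)  # False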
11 changes: 3 additions & 8 deletions mars/dataframe/groupby/sort.py
@@ -18,17 +18,12 @@
from ... import opcodes as OperandDef
from ...core import OutputType
from ...core.operand import MapReduceOperand, OperandStage
from ...serialization.serializables import (
Int32Field,
ListField,
)
from ...utils import (
lazy_import,
)
from ...serialization.serializables import Int32Field, ListField
from ...utils import lazy_import
from ..operands import DataFrameOperandMixin
from ..sort.psrs import DataFramePSRSChunkOperand

cudf = lazy_import("cudf", globals=globals())
cudf = lazy_import("cudf")


def _series_to_df(in_series, xdf):
12 changes: 12 additions & 0 deletions mars/dataframe/groupby/tests/test_groupby_execution.py
@@ -722,6 +722,18 @@ def _disallow_combine_and_agg(ctx, op):
pd.testing.assert_frame_equal(result.sort_index(), raw.groupby("c1").agg("sum"))


def test_distributed_groupby_agg(setup_cluster):
rs = np.random.RandomState(0)
raw = pd.DataFrame(rs.rand(50000, 10))
df = md.DataFrame(raw, chunk_size=raw.shape[0] // 2)
with option_context({"chunk_store_limit": 1024**2}):
r = df.groupby(0).sum(combine_size=1)
result = r.execute().fetch()
pd.testing.assert_frame_equal(result, raw.groupby(0).sum())
# shuffle should have been used: the result spans multiple chunks
assert len(r._fetch_infos()["memory_size"]) > 1


def test_groupby_agg_str_cat(setup):
agg_fun = lambda x: x.str.cat(sep="_", na_rep="NA")

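A back-of-the-envelope check of why this test should land on the shuffle path. The sizes assume float64 and are approximate, and the one-entry-per-chunk behavior of _fetch_infos()["memory_size"] is inferred from the assertion above:

# 50000 rows x 10 float64 columns, split into 2 chunks:
rows, cols = 50_000, 10
total_bytes = rows * cols * 8        # ~3.8 MiB for the whole frame
chunk_bytes = total_bytes // 2       # ~1.9 MiB per chunk

# grouping by a column of random floats barely deduplicates, so each
# chunk's aggregated size is roughly its raw size:
agg_bytes = chunk_bytes
estimated_result = agg_bytes * 2     # ~3.8 MiB

chunk_store_limit = 1024**2          # 1 MiB, lowered by option_context
print(agg_bytes > 1024**2)                   # True: chunks big enough to trust
print(estimated_result > chunk_store_limit)  # True: result exceeds the limit

# on a multi-worker cluster both conditions hold, _choose_tree_method
# returns False, and the shuffle output spans multiple chunks, which is
# what the memory_size assertion checks.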
