Commit f8a22c9: test
Colin Ho authored and Colin Ho committed Nov 7, 2024
1 parent 27295f3 commit f8a22c9
Showing 2 changed files with 56 additions and 44 deletions.
daft/sql/sql_scan.py (45 additions, 41 deletions)
@@ -185,54 +185,58 @@ def _get_partition_bounds(self, num_scan_tasks: int) -> list[Any]:
             )
 
         if self._partition_bound_strategy == PartitionBoundStrategy.PERCENTILE:
-            # Try to get percentiles using percentile_disc.
-            # Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
-            percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
-            # Use the OVER clause for SQL Server
-            over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
-            percentile_sql = self.conn.construct_sql_query(
-                self.sql,
-                projection=[
-                    f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
-                    for i, percentile in enumerate(percentiles)
-                ],
-                limit=1,
-            )
-            pa_table = self.conn.execute_sql_query(percentile_sql)
+            try:
+                # Try to get percentiles using percentile_disc.
+                # Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
+                percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
+                # Use the OVER clause for SQL Server dialects
+                over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
+                percentile_sql = self.conn.construct_sql_query(
+                    self.sql,
+                    projection=[
+                        f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
+                        for i, percentile in enumerate(percentiles)
+                    ],
+                    limit=1,
+                )
+                pa_table = self.conn.execute_sql_query(percentile_sql)
 
-            if pa_table.num_rows != 1:
-                raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")
+                if pa_table.num_rows != 1:
+                    raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")
 
-            if pa_table.num_columns != num_scan_tasks + 1:
-                raise RuntimeError(
-                    f"Failed to get partition bounds: expected {num_scan_tasks + 1} percentiles, but got {pa_table.num_columns}."
-                )
+                if pa_table.num_columns != num_scan_tasks + 1:
+                    raise RuntimeError(
+                        f"Failed to get partition bounds: expected {num_scan_tasks + 1} percentiles, but got {pa_table.num_columns}."
+                    )
 
-            pydict = Table.from_arrow(pa_table).to_pydict()
-            assert pydict.keys() == {f"bound_{i}" for i in range(num_scan_tasks + 1)}
-            bounds = [pydict[f"bound_{i}"][0] for i in range(num_scan_tasks + 1)]
+                pydict = Table.from_arrow(pa_table).to_pydict()
+                assert pydict.keys() == {f"bound_{i}" for i in range(num_scan_tasks + 1)}
+                return [pydict[f"bound_{i}"][0] for i in range(num_scan_tasks + 1)]
+            except Exception as e:
+                warnings.warn(
+                    f"Failed to calculate partition bounds for read_sql using percentile strategy: {str(e)}. "
+                    "Falling back to MIN_MAX strategy."
+                )
+                self._partition_bound_strategy = PartitionBoundStrategy.MIN_MAX
 
-        elif self._partition_bound_strategy == PartitionBoundStrategy.MIN_MAX:
-            min_max_sql = self.conn.construct_sql_query(
-                self.sql, projection=[f"MIN({self._partition_col}) as min", f"MAX({self._partition_col}) as max"]
-            )
-            pa_table = self.conn.execute_sql_query(min_max_sql)
+        # Either MIN_MAX was explicitly specified or percentile calculation failed
+        min_max_sql = self.conn.construct_sql_query(
+            self.sql, projection=[f"MIN({self._partition_col}) as min", f"MAX({self._partition_col}) as max"]
+        )
+        pa_table = self.conn.execute_sql_query(min_max_sql)
 
-            if pa_table.num_rows != 1:
-                raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")
-            if pa_table.num_columns != 2:
-                raise RuntimeError(
-                    f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}."
-                )
+        if pa_table.num_rows != 1:
+            raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")
+        if pa_table.num_columns != 2:
+            raise RuntimeError(f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}.")
 
-            pydict = Table.from_arrow(pa_table).to_pydict()
-            assert pydict.keys() == {"min", "max"}
-            min_val = pydict["min"][0]
-            max_val = pydict["max"][0]
-            range_size = (max_val - min_val) / num_scan_tasks
-            bounds = [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]
-
-        return bounds
+        pydict = Table.from_arrow(pa_table).to_pydict()
+        assert pydict.keys() == {"min", "max"}
+        min_val = pydict["min"][0]
+        max_val = pydict["max"][0]
+        range_size = (max_val - min_val) / num_scan_tasks
+        return [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]
 
     def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]:
         return iter([self._construct_scan_task(pushdowns, num_rows=total_rows, size_bytes=math.ceil(total_size))])
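Note on the hunk above: the new control flow attempts the percentile query first, and on any failure warns and falls through to the MIN/MAX path. Below is a minimal, self-contained sketch of that pattern, not daft's actual implementation: it uses sqlite3 purely because sqlite lacks percentile_disc and so conveniently triggers the fallback, and the helper name, table, and column are hypothetical.

import sqlite3
import warnings

def get_partition_bounds(conn, table, col, num_scan_tasks):
    # PERCENTILE strategy: ask the database for exact per-partition cut points.
    # percentile_disc returns actual column values, so the bounds are safe to
    # use in <= / >= range predicates.
    percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
    try:
        projection = ", ".join(
            f"percentile_disc({p}) WITHIN GROUP (ORDER BY {col}) AS bound_{i}"
            for i, p in enumerate(percentiles)
        )
        row = conn.execute(f"SELECT {projection} FROM {table} LIMIT 1").fetchone()
        return list(row)
    except Exception as e:
        # Mirrors the commit: warn, then fall through to the MIN_MAX strategy.
        warnings.warn(f"Percentile strategy failed: {e}. Falling back to MIN_MAX.")

    # MIN_MAX strategy: evenly spaced bounds across [MIN, MAX].
    min_val, max_val = conn.execute(f"SELECT MIN({col}), MAX({col}) FROM {table}").fetchone()
    range_size = (max_val - min_val) / num_scan_tasks
    return [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (id INTEGER)")
conn.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(100)])
# sqlite rejects percentile_disc, so this warns and returns the MIN_MAX bounds:
print(get_partition_bounds(conn, "t", "id", 4))  # [0, 24.75, 49.5, 74.25, 99]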
tests/integration/sql/test_sql.py (11 additions, 3 deletions)
@@ -32,24 +32,31 @@ def test_sql_create_dataframe_ok(test_db, pdf) -> None:
 
 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [2, 3, 4])
-def test_sql_partitioned_read(test_db, num_partitions, pdf) -> None:
+@pytest.mark.parametrize("partition_bound_strategy", ["min-max", "percentile"])
+def test_sql_partitioned_read(test_db, num_partitions, partition_bound_strategy, pdf) -> None:
     row_size_bytes = daft.from_pandas(pdf).schema().estimate_row_size_bytes()
     num_rows_per_partition = len(pdf) / num_partitions
     with daft.execution_config_ctx(
         read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition),
         scan_tasks_min_size_bytes=0,
         scan_tasks_max_size_bytes=0,
     ):
-        df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id")
+        df = daft.read_sql(
+            f"SELECT * FROM {TEST_TABLE_NAME}",
+            test_db,
+            partition_col="id",
+            partition_bound_strategy=partition_bound_strategy,
+        )
         assert df.num_partitions() == num_partitions
         assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
 
 
 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [1, 2, 3, 4])
 @pytest.mark.parametrize("partition_col", ["id", "float_col", "date_col", "date_time_col"])
+@pytest.mark.parametrize("partition_bound_strategy", ["min-max", "percentile"])
 def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
-    test_db, num_partitions, partition_col, pdf
+    test_db, num_partitions, partition_col, partition_bound_strategy, pdf
 ) -> None:
     with daft.execution_config_ctx(
         scan_tasks_min_size_bytes=0,
@@ -60,6 +67,7 @@ def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
             test_db,
             partition_col=partition_col,
             num_partitions=num_partitions,
+            partition_bound_strategy=partition_bound_strategy,
         )
         assert df.num_partitions() == num_partitions
         assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
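For context, these parametrizations exercise the partition_bound_strategy argument that daft.read_sql accepts after this change. A hedged usage sketch follows; the connection string and table name are hypothetical, and the two strategy strings are exactly the values the tests pass.

import daft

df = daft.read_sql(
    "SELECT * FROM example_table",   # hypothetical table
    "sqlite://example.db",           # hypothetical connection string
    partition_col="id",
    num_partitions=4,
    # "percentile" asks the database for exact bounds via percentile_disc;
    # with this commit it falls back to min-max bounds if that query fails.
    partition_bound_strategy="percentile",
)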