Skip to content

Commit

Permalink
[FEAT] Changes the default count() behavior to perform a global row count instead (#2653)
Browse files Browse the repository at this point in the history

Closes: #1996 

Our new `df.count()` behavior will be similar to `SELECT COUNT(*)` in
SQL, returning a new dataframe with a single `"count"` column.

---------

Co-authored-by: Jay Chia <jaychia94@gmail.com@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Aug 15, 2024
1 parent 0ddc361 commit b961ad3
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
37 changes: 37 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,11 +1969,48 @@ def any_value(self, *cols: ColumnInputType) -> "DataFrame":
def count(self, *cols: ColumnInputType) -> "DataFrame":
    """Performs a global count on the DataFrame.

    When called with no arguments (``df.count()``), this behaves like SQL's
    ``SELECT COUNT(*)``: the result is a new dataframe holding a single
    ``"count"`` column with one row containing the total row count.

    >>> import daft
    >>> df = daft.from_pydict({"foo": [1, None, None], "bar": [None, 2, 2]})
    >>> df.count().show()
    ╭────────╮
    │ count  │
    │ ---    │
    │ UInt64 │
    ╞════════╡
    │ 3      │
    ╰────────╯
    <BLANKLINE>
    (Showing first 1 of 1 rows)

    When column names are provided, each named column is instead counted for
    non-null values, mirroring ``SELECT COUNT(foo), COUNT(bar) FROM df``.

    >>> df.count("foo", "bar").show()
    ╭────────┬────────╮
    │ foo    ┆ bar    │
    │ ---    ┆ ---    │
    │ UInt64 ┆ UInt64 │
    ╞════════╪════════╡
    │ 1      ┆ 2      │
    ╰────────┴────────╯
    <BLANKLINE>
    (Showing first 1 of 1 rows)

    Args:
        *cols (Union[str, Expression]): columns to count

    Returns:
        DataFrame: Globally aggregated count. Should be a single row.
    """
    # No columns given: treat this as COUNT(*), which is what most users expect.
    if not cols:
        return DataFrame(self._builder.count())

    # Columns given: count non-null values per specified column.
    return self._apply_agg_fn(Expression.count, cols)

@DataframePublicAPI
Expand Down
8 changes: 8 additions & 0 deletions tests/cookbook/test_count_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ def test_count_rows(daft_df, service_requests_csv_pd_df, repartition_nparts):
assert daft_df_row_count == service_requests_csv_pd_df.shape[0]


def test_dataframe_count_no_args(daft_df, service_requests_csv_pd_df):
    """Counts rows using `df.count()` without any arguments"""
    counted = daft_df.count().to_pydict()
    assert "count" in counted
    count_column = counted["count"]
    assert len(count_column) == 1
    expected_rows = service_requests_csv_pd_df.shape[0]
    assert count_column[0] == expected_rows


def test_filtered_count_rows(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Count rows on a table filtered by a certain condition"""
daft_df_row_count = daft_df.repartition(repartition_nparts).where(col("Borough") == "BROOKLYN").count_rows()
Expand Down
12 changes: 6 additions & 6 deletions tests/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,14 @@ def test_parquet_helper(data_and_type, use_daft_writer):
after.show(10)
after = after.sort(col("_index"))
assert before.to_pydict() == after.to_pydict()
assert [x for x in before.explode(col("nested_col")).count().collect()] == [
x for x in after.explode(col("nested_col")).count().collect()
assert [x for x in before.explode(col("nested_col")).count(*before.column_names).collect()] == [
x for x in after.explode(col("nested_col")).count(*after.column_names).collect()
]
before = before.limit(50)
after = after.limit(50)
assert before.to_pydict() == after.to_pydict()
assert [x for x in before.explode(col("nested_col")).count().collect()] == [
x for x in after.explode(col("nested_col")).count().collect()
assert [x for x in before.explode(col("nested_col")).count(*before.column_names).collect()] == [
x for x in after.explode(col("nested_col")).count(*after.column_names).collect()
]

# Test Arrow write with Daft read.
Expand All @@ -242,14 +242,14 @@ def test_parquet_helper(data_and_type, use_daft_writer):
assert before.to_pydict() == after.to_pydict()
pd_table = before.to_pandas().explode("nested_col")
assert [pd_table.count().get("nested_col")] == [
x["nested_col"] for x in after.explode(col("nested_col")).count().collect()
x["nested_col"] for x in after.explode(col("nested_col")).count(*after.column_names).collect()
]
before = before.take(list(range(min(before.num_rows, 50))))
after = after.limit(50)
assert before.to_pydict() == after.to_pydict()
pd_table = before.to_pandas().explode("nested_col")
assert [pd_table.count().get("nested_col")] == [
x["nested_col"] for x in after.explode(col("nested_col")).count().collect()
x["nested_col"] for x in after.explode(col("nested_col")).count(*after.column_names).collect()
]

# The normal case where the last row `nested.field1` is contained within a single data page.
Expand Down

0 comments on commit b961ad3

Please sign in to comment.