Skip to content

Commit

Permalink
feat(api): add table.nunique() for counting unique table rows
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jul 31, 2023
1 parent 7cd5835 commit adcd762
Show file tree
Hide file tree
Showing 15 changed files with 195 additions and 6 deletions.
3 changes: 2 additions & 1 deletion ci/udf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

project(impala_test_udfs LANGUAGES CXX)
cmake_minimum_required(VERSION 3.22)

project(impala_test_udfs LANGUAGES CXX)
set(CMAKE_CXX_COMPILER clang++)

# where to put generated libraries and binaries
Expand Down
39 changes: 39 additions & 0 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import FunctionElement, GenericFunction

import ibis.common.exceptions as com
Expand Down Expand Up @@ -493,13 +494,50 @@ def _count_star(t, op):
return sa.func.count(t.translate(ops.Where(where, 1, None)))


def _count_distinct_star(t, op):
schema = op.arg.schema
cols = [sa.column(col, t.get_sqla_type(typ)) for col, typ in schema.items()]

if t._supports_tuple_syntax:
func = lambda *cols: sa.func.count(sa.distinct(sa.tuple_(*cols)))
else:
func = count_distinct

if op.where is None:
return func(*cols)

if t._has_reduction_filter_syntax:
return func(*cols).filter(t.translate(op.where))

if not t._supports_tuple_syntax and len(cols) > 1:
raise com.UnsupportedOperationError(
f"{t._dialect_name} backend doesn't support `COUNT(DISTINCT ...)` with a "
"filter with more than one column"
)

return sa.func.count(t.translate(ops.Where(op.where, sa.distinct(*cols), None)))


def _extract(fmt: str):
def translator(t, op: ops.Node):
return sa.cast(sa.extract(fmt, t.translate(op.arg)), sa.SMALLINT)

return translator


class count_distinct(FunctionElement):
inherit_cache = True


@compiles(count_distinct)
def compile_count_distinct(element, compiler, **kw):
quote_identifier = compiler.preparer.quote_identifier
clauses = ", ".join(
quote_identifier(compiler.process(clause, **kw)) for clause in element.clauses
)
return f"COUNT(DISTINCT {clauses})"


class array_map(FunctionElement):
pass

Expand All @@ -522,6 +560,7 @@ class array_filter(FunctionElement):
ops.NotContains: _contains(lambda left, right: left.notin_(right)),
ops.Count: reduction(sa.func.count),
ops.CountStar: _count_star,
ops.CountDistinctStar: _count_distinct_star,
ops.Sum: reduction(sa.func.sum),
ops.Mean: reduction(sa.func.avg),
ops.Min: reduction(sa.func.min),
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/base/sql/alchemy/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class AlchemyExprTranslator(ExprTranslator):

_bool_aggs_need_cast_to_int32 = True
_has_reduction_filter_syntax = False
_supports_tuple_syntax = False
_integer_to_timestamp = staticmethod(sa.func.to_timestamp)
_timestamp_type = sa.TIMESTAMP

Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,12 @@ def table_column(translator, op):
return quoted_name


def _count_distinct_star(t, op):
raise com.UnsupportedOperationError(
"BigQuery doesn't support COUNT(DISTINCT ...) with multiple columns"
)


OPERATION_REGISTRY = {
**operation_registry,
# Literal
Expand Down Expand Up @@ -800,6 +806,7 @@ def table_column(translator, op):
ops.StartsWith: fixed_arity("STARTS_WITH", 2),
ops.EndsWith: fixed_arity("ENDS_WITH", 2),
ops.TableColumn: table_column,
ops.CountDistinctStar: _count_distinct_star,
}

_invalid_operations = {
Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,3 +1412,12 @@ def _array_zip(op: ops.ArrayZip, **kw: Any) -> str:
sql_arg = sql_arg.sql(dialect="clickhouse")
arglist.append(sql_arg)
return f"arrayZip({', '.join(arglist)})"


@translate_val.register(ops.CountDistinctStar)
def _count_distinct_star(op: ops.CountDistinctStar, **kw: Any) -> str:
column_list = ", ".join(map(_sql, map(sg.column, op.arg.schema.names)))
if op.where is not None:
return f"countDistinctIf(({column_list}), {translate_val(op.where, **kw)})"
else:
return f"countDistinct(({column_list}))"
10 changes: 10 additions & 0 deletions ibis/backends/dask/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
execute_between,
execute_cast_series_array,
execute_cast_series_generic,
execute_count_distinct_star_frame,
execute_count_distinct_star_frame_filter,
execute_count_star_frame,
execute_count_star_frame_filter,
execute_count_star_frame_groupby,
Expand Down Expand Up @@ -107,6 +109,14 @@
((dd.DataFrame, type(None)), execute_count_star_frame),
((dd.DataFrame, dd.Series), execute_count_star_frame_filter),
],
ops.CountDistinctStar: [
(
(ddgb.DataFrameGroupBy, type(None)),
execute_count_star_frame_groupby,
),
((dd.DataFrame, type(None)), execute_count_distinct_star_frame),
((dd.DataFrame, dd.Series), execute_count_distinct_star_frame_filter),
],
ops.NullIfZero: [((dd.Series,), execute_null_if_zero_series)],
ops.Between: [
(
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ class DuckDBSQLExprTranslator(AlchemyExprTranslator):
_registry = operation_registry
_rewrites = AlchemyExprTranslator._rewrites.copy()
_has_reduction_filter_syntax = True
_supports_tuple_syntax = True
_dialect_name = "duckdb"

type_mapper = DuckDBType


Expand Down
22 changes: 17 additions & 5 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,12 @@ def execute_std_series_groupby_mask(op, data, mask, aggcontext=None, **kwargs):

@execute_node.register(ops.CountStar, DataFrameGroupBy, type(None))
def execute_count_star_frame_groupby(op, data, _, **kwargs):
result = data.size()
# FIXME(phillipc): We should not hard code this column name
result.name = 'count'
return result
return data.size()


@execute_node.register(ops.CountDistinctStar, DataFrameGroupBy, type(None))
def execute_count_distinct_star_frame_groupby(op, data, _, **kwargs):
return data.nunique()


@execute_node.register(ops.Reduction, pd.Series, (pd.Series, type(None)))
Expand Down Expand Up @@ -899,7 +901,17 @@ def execute_count_star_frame(op, data, _, **kwargs):

@execute_node.register(ops.CountStar, pd.DataFrame, pd.Series)
def execute_count_star_frame_filter(op, data, where, **kwargs):
return len(data) - (len(where) - where.sum())
return len(data) - len(where) + where.sum()


@execute_node.register(ops.CountDistinctStar, pd.DataFrame, type(None))
def execute_count_distinct_star_frame(op, data, _, **kwargs):
return len(data.drop_duplicates())


@execute_node.register(ops.CountDistinctStar, pd.DataFrame, pd.Series)
def execute_count_distinct_star_frame_filter(op, data, filt, **kwargs):
return len(data.loc[filt].drop_duplicates())


@execute_node.register(ops.BitAnd, pd.Series, (pd.Series, type(None)))
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,3 +1157,11 @@ def execute_view(op, *, ctx: pl.SQLContext, **kw):
@translate.register(ops.SelfReference)
def execute_self_reference(op, **kw):
return translate(op.table, **kw)


@translate.register(ops.CountDistinctStar)
def execute_count_distinct_star(op, **kw):
arg = pl.struct(*op.arg.schema.names)
if op.where is not None:
arg = arg.filter(translate(op.where, **kw))
return arg.n_unique()
1 change: 1 addition & 0 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator):
_registry = operation_registry.copy()
_rewrites = AlchemyExprTranslator._rewrites.copy()
_has_reduction_filter_syntax = True
_supports_tuple_syntax = True
_dialect_name = "postgresql"

# it does support it, but we can't use it because of support for pivot
Expand Down
15 changes: 15 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,21 @@ def compile_count_star(t, op, aggcontext=None, **kwargs):
return src_table.select(col)


@compiles(ops.CountDistinctStar)
def compile_count_distinct_star(t, op, aggcontext=None, **kwargs):
src_table = t.translate(op.arg, **kwargs)
src_col = F.struct(*map(F.col, op.arg.schema.names))

if (where := op.where) is not None:
src_col = F.when(t.translate(where, **kwargs), src_col)

src_col = F.countDistinct(src_col)
if aggcontext is not None:
return src_col
else:
return src_table.select(src_col)


@compiles(ops.Max)
@compiles(ops.CumulativeMax)
def compile_max(t, op, **kwargs):
Expand Down
41 changes: 41 additions & 0 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,47 @@ def test_reduction_ops(
np.testing.assert_array_equal(result, expected)


@pytest.mark.parametrize(
("ibis_cond", "pandas_cond"),
[
param(lambda _: None, lambda _: slice(None), id="no_cond"),
param(
lambda t: t.string_col.isin(["1", "7"]),
lambda t: t.string_col.isin(["1", "7"]),
id="cond",
marks=[
mark.notyet(
["snowflake", "mysql"],
raises=com.UnsupportedOperationError,
reason="backend does not support filtered count distinct with more than one column",
),
],
),
],
)
@mark.notyet(
["bigquery", "druid", "mssql", "oracle", "sqlite"],
raises=(
sa.exc.OperationalError,
sa.exc.DatabaseError,
com.UnsupportedOperationError,
),
reason="backend doesn't support count distinct with multiple columns",
)
@mark.notyet(
["datafusion", "impala"],
raises=com.OperationNotDefinedError,
reason="no one has attempted implementation yet",
)
def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond):
table = alltypes[["int_col", "double_col", "string_col"]]
expr = table.nunique(where=ibis_cond(table))
result = expr.execute()
df = df[["int_col", "double_col", "string_col"]]
expected = len(df.loc[pandas_cond(df)].drop_duplicates())
assert int(result) == int(expected)


@pytest.mark.parametrize(
('result_fn', 'expected_fn'),
[
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class TrinoSQLExprTranslator(AlchemyExprTranslator):
_registry = operation_registry.copy()
_rewrites = AlchemyExprTranslator._rewrites.copy()
_has_reduction_filter_syntax = True
_supports_tuple_syntax = True
_integer_to_timestamp = staticmethod(sa.func.from_unixtime)

_forbids_frame_clause = (
Expand Down
7 changes: 7 additions & 0 deletions ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ class CountStar(Filterable, Reduction):
output_dtype = dt.int64


@public
class CountDistinctStar(Filterable, Reduction):
arg = rlz.table

output_dtype = dt.int64


@public
class Arbitrary(Filterable, Reduction):
arg = rlz.column(rlz.any)
Expand Down
35 changes: 35 additions & 0 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2012,6 +2012,41 @@ def filter(
]
return an.apply_filter(self.op(), predicates).to_expr()

def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of unique rows in the table.
Parameters
----------
where
Optional boolean expression to filter rows when counting.
Returns
-------
IntegerScalar
Number of unique rows in the table
Examples
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.memtable({"a": ["foo", "bar", "bar"]})
>>> t
┏━━━━━━━━┓
┃ a ┃
┡━━━━━━━━┩
│ string │
├────────┤
│ foo │
│ bar │
│ bar │
└────────┘
>>> t.nunique()
2
>>> t.nunique(t.a != "foo")
1
"""
return ops.CountDistinctStar(self, where=where).to_expr()

def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of rows in the table.
Expand Down

0 comments on commit adcd762

Please sign in to comment.