diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py index 241ad8863809..e26d12de4085 100644 --- a/ibis/backends/base/sql/alchemy/registry.py +++ b/ibis/backends/base/sql/alchemy/registry.py @@ -482,6 +482,13 @@ def _count_star(t, op): return sa.func.count(t.translate(ops.Where(where, 1, None))) +def _extract(fmt: str): + def translator(t, op: ops.Node): + return sa.cast(sa.extract(fmt, t.translate(op.arg)), sa.SMALLINT) + + return translator + + sqlalchemy_operation_registry: Dict[Any, Any] = { ops.Alias: _alias, ops.And: fixed_arity(operator.and_, 2), @@ -612,6 +619,13 @@ def _count_star(t, op): ops.BitwiseRightShift: _bitwise_op(">>"), ops.BitwiseNot: _bitwise_not, ops.JSONGetItem: fixed_arity(lambda x, y: x.op("->")(y), 2), + ops.ExtractYear: _extract('year'), + ops.ExtractQuarter: _extract('quarter'), + ops.ExtractMonth: _extract('month'), + ops.ExtractDay: _extract('day'), + ops.ExtractHour: _extract('hour'), + ops.ExtractMinute: _extract('minute'), + ops.ExtractSecond: _extract('second'), } diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index 2b1b64112746..439bfc77fab8 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -27,6 +27,7 @@ from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported from ibis.backends.base.sql.alchemy.registry import ( _bitwise_op, + _extract, geospatial_functions, get_col_or_deferred_col, ) @@ -38,19 +39,16 @@ operation_registry.update(geospatial_functions) -def _extract(fmt, output_type=sa.SMALLINT): - def translator(t, op, output_type=output_type): - sa_arg = t.translate(op.arg) - return sa.cast(sa.extract(fmt, sa_arg), output_type) - - return translator +def _epoch(t, op): + sa_arg = t.translate(op.arg) + return sa.cast(sa.extract('epoch', sa_arg), sa.INTEGER) def _second(t, op): # extracting the second gives us the fractional part as well, so smash that # with a cast to SMALLINT sa_arg = t.translate(op.arg) - return sa.cast(sa.func.FLOOR(sa.extract('second', sa_arg)), sa.SMALLINT) + return sa.cast(sa.func.floor(sa.extract('second', sa_arg)), sa.SMALLINT) def _millisecond(t, op): @@ -570,15 +568,9 @@ def _array_collect(t, op): ops.TimestampSub: fixed_arity(operator.sub, 2), ops.TimestampDiff: fixed_arity(operator.sub, 2), ops.Strftime: _strftime, - ops.ExtractYear: _extract('year'), - ops.ExtractMonth: _extract('month'), - ops.ExtractDay: _extract('day'), + ops.ExtractEpochSeconds: _epoch, ops.ExtractDayOfYear: _extract('doy'), - ops.ExtractQuarter: _extract('quarter'), - ops.ExtractEpochSeconds: _extract('epoch', sa.Integer), ops.ExtractWeekOfYear: _extract('week'), - ops.ExtractHour: _extract('hour'), - ops.ExtractMinute: _extract('minute'), ops.ExtractSecond: _second, ops.ExtractMillisecond: _millisecond, ops.DayOfWeekIndex: _day_of_week_index, diff --git a/ibis/backends/tests/test_column.py b/ibis/backends/tests/test_column.py index b7953bcf72b9..fd58427d74cf 100644 --- a/ibis/backends/tests/test_column.py +++ b/ibis/backends/tests/test_column.py @@ -83,7 +83,7 @@ def test_distinct_column(alltypes, df, column): ("day", set(range(1, 32))), ], ) -@pytest.mark.notimpl(["datafusion", "trino"]) +@pytest.mark.notimpl(["datafusion"]) @pytest.mark.notyet(["impala"]) def test_date_extract_field(con, opname, expected): op = operator.methodcaller(opname) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index bc7a0159bf3c..e47a1a1fb81c 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -27,7 +27,7 @@ ), ], ) -@pytest.mark.notimpl(["datafusion", "trino"]) +@pytest.mark.notimpl(["datafusion"]) def test_date_extract(backend, alltypes, df, attr, expr_fn): expr = getattr(expr_fn(alltypes.timestamp_col), attr)() expected = getattr(df.timestamp_col.dt, attr).astype('int32') @@ -44,7 +44,8 @@ def test_date_extract(backend, alltypes, df, attr, expr_fn): 'month', 'day', param( - 'day_of_year', marks=pytest.mark.notimpl(["bigquery", "impala", "mssql"]) + 'day_of_year', + marks=pytest.mark.notimpl(["bigquery", "impala", "mssql", "trino"]), ), param('quarter', marks=pytest.mark.notimpl(["mssql"])), 'hour', @@ -52,7 +53,7 @@ def test_date_extract(backend, alltypes, df, attr, expr_fn): 'second', ], ) -@pytest.mark.notimpl(["datafusion", "trino"]) +@pytest.mark.notimpl(["datafusion"]) def test_timestamp_extract(backend, alltypes, df, attr): method = getattr(alltypes.timestamp_col, attr) expr = method().name(attr) @@ -77,7 +78,7 @@ def test_timestamp_extract(backend, alltypes, df, attr): 359, id='millisecond', marks=[ - pytest.mark.notimpl(["clickhouse", "pyspark"]), + pytest.mark.notimpl(["clickhouse", "pyspark", "trino"]), pytest.mark.broken( ["mysql"], reason="MySQL implementation of milliseconds is broken", @@ -88,17 +89,17 @@ def test_timestamp_extract(backend, alltypes, df, attr): lambda x: x.day_of_week.index(), 1, id='day_of_week_index', - marks=pytest.mark.notimpl(["mssql"]), + marks=pytest.mark.notimpl(["mssql", "trino"]), ), param( lambda x: x.day_of_week.full_name(), 'Tuesday', id='day_of_week_full_name', - marks=pytest.mark.notimpl(["mssql"]), + marks=pytest.mark.notimpl(["mssql", "trino"]), ), ], ) -@pytest.mark.notimpl(["datafusion", "snowflake", "trino"]) +@pytest.mark.notimpl(["datafusion", "snowflake"]) def test_timestamp_extract_literal(con, func, expected): value = ibis.timestamp('2015-09-01 14:48:05.359') assert con.execute(func(value).name("tmp")) == expected