Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): support quarterly truncation #9715

Merged
merged 8 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import dask.dataframe as dd
import numpy as np
import pandas as pd
from packaging.version import parse as vparse

import ibis.backends.dask.kernels as dask_kernels
import ibis.expr.operations as ops
Expand Down Expand Up @@ -97,23 +96,6 @@ def mapper(df, cases, results, default):

return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

@classmethod
def visit(cls, op: ops.TimestampTruncate | ops.DateTruncate, arg, unit):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method was entirely duplicated.

# TODO(kszucs): should use serieswise()
if vparse(pd.__version__) >= vparse("2.2"):
units = {"m": "min"}
else:
units = {"m": "Min", "ms": "L"}

unit = units.get(unit.short, unit.short)

if unit in "YMWD":
return arg.dt.to_period(unit).dt.to_timestamp()
try:
return arg.dt.floor(unit)
except ValueError:
return arg.dt.to_period(unit).dt.to_timestamp()

@classmethod
def visit(cls, op: ops.IntervalFromInteger, unit, **kwargs):
if unit.short in {"Y", "Q", "M", "W"}:
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def visit(cls, op: ops.TimestampTruncate | ops.DateTruncate, arg, unit):

unit = units.get(unit.short, unit.short)

if unit in "YMWD":
if unit in "YQMWD":
return arg.dt.to_period(unit).dt.to_timestamp()
try:
return arg.dt.floor(unit)
Expand Down
9 changes: 6 additions & 3 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ def visit_ExtractSecond(self, op, *, arg):
def visit_TimestampTruncate(self, op, *, arg, unit):
unit_mapping = {
"Y": "year",
"Q": "quarter",
"M": "month",
"W": "week",
"D": "day",
Expand All @@ -847,10 +848,12 @@ def visit_TimestampTruncate(self, op, *, arg, unit):
"us": "us",
}

if (unit := unit_mapping.get(unit.short)) is None:
raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}")
if (raw_unit := unit_mapping.get(unit.short)) is None:
raise com.UnsupportedOperationError(
f"Unsupported truncate unit {unit.short!r}"
)

return self.f.date_trunc(unit, arg)
return self.f.date_trunc(raw_unit, arg)

def visit_DateTruncate(self, op, *, arg, unit):
return self.visit_TimestampTruncate(op, arg=arg, unit=unit)
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
def visit_TimestampTruncate(self, op, *, arg, unit):
converters = {
"Y": "toStartOfYear",
"Q": "toStartOfQuarter",
"M": "toStartOfMonth",
"W": "toMonday",
"D": "toDate",
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,19 @@ def visit_LRStrip(self, op, *, arg, position):
)

def visit_DateTimestampTruncate(self, op, *, arg, unit):
if unit.short == "Q":
# adapted from https://stackoverflow.com/a/11884743
return (
# January 1 of the year of the `arg`
self.f.makedate(self.f.year(arg), 1)
# add the current quarter's number of quarters minus one to Jan 1
# first quarter: add zero
# second quarter: add one
# third quarter: add two
# fourth quarter: add three
+ sge.Interval(this=self.f.quarter(arg) - 1, unit=self.v.QUARTER)
)

truncate_formats = {
"s": "%Y-%m-%d %H:%i:%s",
"m": "%Y-%m-%d %H:%i:00",
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def visit_Xor(self, op, *, left, right):
def visit_DateTruncate(self, op, *, arg, unit):
trunc_unit_mapping = {
"Y": "year",
"Q": "Q",
"M": "MONTH",
"W": "IW",
"D": "DDD",
Expand Down
14 changes: 14 additions & 0 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,20 @@ def visit_Modulus(self, op, *, left, right):
return self.f.anon.mod(left, right)

def _temporal_truncate(self, func, arg, unit):
if unit.short == "Q":
return sge.Case(
ifs=[
self.if_(
sge.Between(
this=self.cast(self.f.strftime("%m", arg), dt.int32),
low=sge.convert(lower),
high=sge.convert(lower + 2),
),
self.f.strftime(f"%Y-{lower:0>2}-01", arg),
)
for lower in range(1, 13, 3)
],
)
modifiers = {
DateUnit.DAY: ("start of day",),
DateUnit.WEEK: ("weekday 0", "-6 days"),
Expand Down
14 changes: 13 additions & 1 deletion ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,17 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df):
),
],
),
param(
"Q",
"Q",
marks=[
pytest.mark.notimpl(
["polars"],
raises=AssertionError,
reason="numpy array are different",
),
],
),
param(
"M",
"M",
Expand Down Expand Up @@ -399,7 +410,7 @@ def test_timestamp_truncate(backend, alltypes, df, ibis_unit, pandas_unit):

dtns = df.timestamp_col.dt

if ibis_unit in ("Y", "M", "D", "W"):
if ibis_unit in ("Y", "Q", "M", "D", "W"):
expected = dtns.to_period(pandas_unit).dt.to_timestamp()
else:
expected = dtns.floor(pandas_unit)
Expand All @@ -414,6 +425,7 @@ def test_timestamp_truncate(backend, alltypes, df, ibis_unit, pandas_unit):
"unit",
[
"Y",
"Q",
"M",
"D",
param(
Expand Down
Loading