diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index a9d0f71ae9a0..ff12e594a9d9 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -942,9 +942,7 @@ def visit_DayOfWeekName(self, op, *, arg): ) def visit_IntervalFromInteger(self, op, *, arg, unit): - return sge.Interval( - this=sge.convert(arg), unit=sge.Var(this=unit.singular.upper()) - ) + return sge.Interval(this=arg, unit=self.v[unit.singular.upper()]) ### String Instruments def visit_Strip(self, op, *, arg): diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py index d55a24b54c8f..a68f122f7155 100644 --- a/ibis/backends/sql/compilers/mssql.py +++ b/ibis/backends/sql/compilers/mssql.py @@ -86,14 +86,12 @@ class MSSQLCompiler(SQLGlotCompiler): ops.BitXor, ops.Covariance, ops.CountDistinctStar, - ops.DateAdd, ops.DateDiff, - ops.DateSub, ops.EndsWith, ops.IntervalAdd, - ops.IntervalFromInteger, - ops.IntervalMultiply, ops.IntervalSubtract, + ops.IntervalMultiply, + ops.IntervalFloorDivide, ops.IsInf, ops.IsNan, ops.LPad, @@ -115,9 +113,7 @@ class MSSQLCompiler(SQLGlotCompiler): ops.StringToDate, ops.StringToTimestamp, ops.StructColumn, - ops.TimestampAdd, ops.TimestampDiff, - ops.TimestampSub, ops.Unnest, ) @@ -404,6 +400,8 @@ def visit_Cast(self, op, *, arg, to): return arg elif from_.is_integer() and to.is_timestamp(): return self.f.dateadd(self.v.s, arg, "1970-01-01 00:00:00") + elif from_.is_integer() and to.is_interval(): + return sge.Interval(this=arg, unit=self.v[to.unit.singular]) return super().visit_Cast(op, arg=arg, to=to) def visit_Sum(self, op, *, arg, where): @@ -500,5 +498,18 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke return result + def visit_TimestampAdd(self, op, *, left, right): + return self.f.dateadd( + right.unit, self.cast(right.this, dt.int64), left, dialect=self.dialect + ) + + def visit_TimestampSub(self, op, *, left, right): + return self.f.dateadd( + right.unit, -self.cast(right.this, dt.int64), left, dialect=self.dialect + ) + + visit_DateAdd = visit_TimestampAdd + visit_DateSub = visit_TimestampSub + compiler = MSSQLCompiler() diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py index d5367f9495a6..96584257c091 100644 --- a/ibis/backends/sql/compilers/mysql.py +++ b/ibis/backends/sql/compilers/mysql.py @@ -120,22 +120,16 @@ def visit_Cast(self, op, *, arg, to): # for TEXT (except when casting of course!) return arg elif from_.is_integer() and to.is_interval(): - return self.visit_IntervalFromInteger( - ops.IntervalFromInteger(op.arg, unit=to.unit), arg=arg, unit=to.unit - ) + return sge.Interval(this=arg, unit=self.v[to.unit.singular.upper()]) elif from_.is_integer() and to.is_timestamp(): return self.f.from_unixtime(arg) return super().visit_Cast(op, arg=arg, to=to) def visit_TimestampDiff(self, op, *, left, right): - return self.f.timestampdiff( - sge.Var(this="SECOND"), right, left, dialect=self.dialect - ) + return self.f.timestampdiff(self.v.SECOND, right, left, dialect=self.dialect) def visit_DateDiff(self, op, *, left, right): - return self.f.timestampdiff( - sge.Var(this="DAY"), right, left, dialect=self.dialect - ) + return self.f.timestampdiff(self.v.DAY, right, left, dialect=self.dialect) def visit_ApproxCountDistinct(self, op, *, arg, where): if where is not None: @@ -317,16 +311,16 @@ def visit_DateTimestampTruncate(self, op, *, arg, unit): def visit_DateTimeDelta(self, op, *, left, right, part): return self.f.timestampdiff( - sge.Var(this=part.this), right, left, dialect=self.dialect + self.v[part.this], right, left, dialect=self.dialect ) visit_TimeDelta = visit_DateDelta = visit_DateTimeDelta def visit_ExtractMillisecond(self, op, *, arg): - return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg) / 1_000) + return self.f.floor(self.f.extract(self.v.microsecond, arg) / 1_000) def visit_ExtractMicrosecond(self, op, *, arg): - return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg)) + return self.f.floor(self.f.extract(self.v.microsecond, arg)) def visit_Strip(self, op, *, arg): return self.visit_LRStrip(op, arg=arg, position="BOTH") @@ -337,14 +331,9 @@ def visit_LStrip(self, op, *, arg): def visit_RStrip(self, op, *, arg): return self.visit_LRStrip(op, arg=arg, position="TRAILING") - def visit_IntervalFromInteger(self, op, *, arg, unit): - return sge.Interval(this=arg, unit=sge.Var(this=op.resolution.upper())) - def visit_TimestampAdd(self, op, *, left, right): if op.right.dtype.unit.short == "ms": - right = sge.Interval( - this=right.this * 1_000, unit=sge.Var(this="MICROSECOND") - ) + right = sge.Interval(this=right.this * 1_000, unit=self.v.MICROSECOND) return self.f.date_add(left, right, dialect=self.dialect) def visit_UnwrapJSONString(self, op, *, arg): diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index c22a40d1dea6..18d8a56a43c0 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -67,7 +67,6 @@ class OracleCompiler(SQLGlotCompiler): ops.TimestampDelta, ops.TimestampFromYMDHMS, ops.TimeFromHMS, - ops.IntervalFromInteger, ops.DayOfWeekIndex, ops.DayOfWeekName, ops.DateDiff, @@ -138,30 +137,37 @@ def visit_Literal(self, op, *, value, dtype): elif dtype.is_uuid(): return sge.convert(str(value)) elif dtype.is_interval(): - if dtype.unit.short in ("Y", "M"): - return self.f.numtoyminterval(value, dtype.unit.name) - elif dtype.unit.short in ("D", "h", "m", "s"): - return self.f.numtodsinterval(value, dtype.unit.name) - else: - raise com.UnsupportedOperationError( - f"Intervals with precision {dtype.unit.name} not supported in Oracle." - ) + return self._value_to_interval(value, dtype.unit) return super().visit_Literal(op, value=value, dtype=dtype) + def _value_to_interval(self, arg, unit): + short = unit.short + + if short in ("Y", "M"): + return self.f.numtoyminterval(arg, unit.singular) + elif short in ("D", "h", "m", "s"): + return self.f.numtodsinterval(arg, unit.singular) + elif short == "ms": + return self.f.numtodsinterval(arg / 1e3, "second") + elif short in "us": + return self.f.numtodsinterval(arg / 1e6, "second") + elif short in "ns": + return self.f.numtodsinterval(arg / 1e9, "second") + else: + raise com.UnsupportedArgumentError( + f"Interval {unit.name} not supported by Oracle" + ) + def visit_Cast(self, op, *, arg, to): - if to.is_interval(): + from_ = op.arg.dtype + if from_.is_numeric() and to.is_interval(): # CASTing to an INTERVAL in Oracle requires specifying digits of # precision that are a pain. There are two helper functions that # should be used instead. - if to.unit.short in ("D", "h", "m", "s"): - return self.f.numtodsinterval(arg, to.unit.name) - elif to.unit.short in ("Y", "M"): - return self.f.numtoyminterval(arg, to.unit.name) - else: - raise com.UnsupportedArgumentError( - f"Interval {to.unit.name} not supported by Oracle" - ) + return self._value_to_interval(arg, to.unit) + elif from_.is_string() and to.is_date(): + return self.f.to_date(arg, "FXYYYY-MM-DD") return self.cast(arg, to) def visit_Limit(self, op, *, parent, n, offset): @@ -457,5 +463,8 @@ def visit_GroupConcat(self, op, *, arg, where, sep, order_by): return out + def visit_IntervalFromInteger(self, op, *, arg, unit): + return self._value_to_interval(arg, unit) + compiler = OracleCompiler() diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index 5074c79b3a53..7f885c37a385 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -378,9 +378,6 @@ def visit_TimestampSub(self, op, *, left, right): visit_DateAdd = visit_TimestampAdd visit_DateSub = visit_TimestampSub - def visit_IntervalFromInteger(self, op, *, arg, unit): - return sge.Interval(this=arg, unit=self.v[unit.name]) - def visit_IntegerRange(self, op, *, start, stop, step): return self.if_( step.neq(0), self.f.array_generate_range(start, stop, step), self.f.array() diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py index b13de521b94f..c2cf994961d9 100644 --- a/ibis/backends/tests/errors.py +++ b/ibis/backends/tests/errors.py @@ -133,8 +133,9 @@ try: from oracledb.exceptions import DatabaseError as OracleDatabaseError + from oracledb.exceptions import InterfaceError as OracleInterfaceError except ImportError: - OracleDatabaseError = None + OracleDatabaseError = OracleInterfaceError = None try: from pyodbc import DataError as PyODBCDataError diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index 6ec0acfb2604..d9ea67601bb4 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -157,7 +157,7 @@ def test_scalar_param(backend, alltypes, df, value, dtype, col): ["2009-01-20", datetime.date(2009, 1, 20), datetime.datetime(2009, 1, 20)], ids=["string", "date", "datetime"], ) -@pytest.mark.notimpl(["druid", "oracle"]) +@pytest.mark.notimpl(["druid"]) def test_scalar_param_date(backend, alltypes, value): param = ibis.param("date") ds_col = alltypes.date_string_col diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 790ffecdc529..1ec79860da72 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -28,6 +28,7 @@ MySQLOperationalError, MySQLProgrammingError, OracleDatabaseError, + OracleInterfaceError, PolarsInvalidOperationError, PolarsPanicException, PsycoPg2InternalError, @@ -500,6 +501,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=com.UnsupportedOperationError, reason="month not implemented", ), + pytest.mark.notyet( + ["oracle"], + raises=OracleInterfaceError, + reason="cursor not open, probably a bug in the sql generated", + ), ], ), param( @@ -512,11 +518,8 @@ def test_date_truncate(backend, alltypes, df, unit): raises=ValueError, reason="Metadata inference failed in `add`.", ), - pytest.mark.notyet( - ["trino"], - raises=com.UnsupportedOperationError, - reason="week not implemented", - ), + pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + pytest.mark.notyet(["oracle"], raises=com.UnsupportedArgumentError), pytest.mark.notyet( ["flink"], raises=Py4JJavaError, @@ -579,8 +582,7 @@ def test_date_truncate(backend, alltypes, df, unit): ], ) @pytest.mark.notimpl( - ["datafusion", "sqlite", "mssql", "oracle", "druid", "exasol"], - raises=com.OperationNotDefinedError, + ["datafusion", "sqlite", "druid", "exasol"], raises=com.OperationNotDefinedError ) def test_integer_to_interval_timestamp( backend, con, alltypes, df, unit, displacement_type @@ -609,54 +611,63 @@ def convert_to_offset(offset, displacement_type=displacement_type): [ param( "Y", - marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + marks=[ + pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + pytest.mark.notyet( + ["polars"], raises=TypeError, reason="not supported by polars" + ), + ], ), param("Q", marks=pytest.mark.xfail), param( "M", - marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + marks=[ + pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + pytest.mark.notyet( + ["polars"], raises=TypeError, reason="not supported by polars" + ), + pytest.mark.notyet( + ["oracle"], + raises=OracleInterfaceError, + reason="cursor not open, probably a bug in the sql generated", + ), + ], ), param( "W", marks=[ pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + pytest.mark.notyet(["oracle"], raises=com.UnsupportedArgumentError), pytest.mark.notimpl( ["risingwave"], raises=PsycoPg2InternalError, reason="Bind error: Invalid unit: week", ), + pytest.mark.notimpl( + ["flink"], + raises=Py4JJavaError, + reason="week is not a valid unit in Flink", + ), ], ), "D", ], ) @pytest.mark.notimpl( - [ - "datafusion", - "flink", - "impala", - "mysql", - "sqlite", - "polars", - "mssql", - "druid", - "oracle", - ], - raises=com.OperationNotDefinedError, + ["datafusion", "sqlite", "druid"], raises=com.OperationNotDefinedError ) @pytest.mark.notimpl( - [ - "sqlite", - ], + ["sqlite"], raises=(com.UnsupportedOperationError, com.OperationNotDefinedError), reason="Handling unsupported op error for DateAdd with weeks", ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) def test_integer_to_interval_date(backend, con, alltypes, df, unit): interval = alltypes.int_col.to_interval(unit=unit) - array = alltypes.date_string_col.split("/") - month, day, year = array[0], array[1], array[2] - date_col = ibis.literal("-").join(["20" + year, month, day]).cast("date") + month = alltypes.date_string_col[:2] + day = alltypes.date_string_col[3:5] + year = alltypes.date_string_col[6:8] + date_col = ("20" + year + "-" + month + "-" + day).cast("date") expr = (date_col + interval).name("tmp") with warnings.catch_warnings(): @@ -708,7 +719,7 @@ def convert_to_offset(x): id="timestamp-add-interval-binop", marks=[ pytest.mark.notimpl( - ["snowflake", "sqlite", "bigquery", "exasol"], + ["snowflake", "sqlite", "bigquery", "exasol", "mssql"], raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["impala"], raises=com.UnsupportedOperationError), @@ -729,7 +740,7 @@ def convert_to_offset(x): id="timestamp-add-interval-binop-different-units", marks=[ pytest.mark.notimpl( - ["sqlite", "polars", "snowflake", "bigquery", "exasol"], + ["sqlite", "polars", "snowflake", "bigquery", "exasol", "mssql"], raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["impala"], raises=com.UnsupportedOperationError), @@ -786,7 +797,7 @@ def convert_to_offset(x): id="timestamp-subtract-timestamp", marks=[ pytest.mark.notimpl( - ["bigquery", "snowflake", "sqlite", "exasol"], + ["bigquery", "snowflake", "sqlite", "exasol", "mssql"], raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), @@ -829,7 +840,8 @@ def convert_to_offset(x): reason="DayTimeIntervalType added in pyspark 3.3", ), pytest.mark.notimpl( - ["bigquery", "druid", "flink"], raises=com.OperationNotDefinedError + ["bigquery", "druid", "flink", "mssql"], + raises=com.OperationNotDefinedError, ), pytest.mark.notyet( ["datafusion"], @@ -845,7 +857,6 @@ def convert_to_offset(x): ), ], ) -@pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError) def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): expr = expr_fn(alltypes, backend).name("tmp") expected = expected_fn(df, backend) @@ -916,7 +927,7 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): ], ) @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) -@pytest.mark.notimpl(["sqlite", "mssql"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["sqlite"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) def test_temporal_binop_pandas_timedelta( backend, con, alltypes, df, timedelta, temporal_fn @@ -1016,7 +1027,7 @@ def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name backend.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["mssql", "exasol", "druid"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["exasol", "druid"], raises=com.OperationNotDefinedError) def test_interval_add_cast_scalar(backend, alltypes): timestamp_date = alltypes.timestamp_col.date() delta = ibis.literal(10).cast("interval('D')") @@ -1026,7 +1037,7 @@ def test_interval_add_cast_scalar(backend, alltypes): backend.assert_series_equal(result, expected.astype(result.dtype)) -@pytest.mark.notimpl(["mssql", "exasol", "druid"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["exasol", "druid"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["flink"], raises=AssertionError, reason="incorrect results") def test_interval_add_cast_column(backend, alltypes, df): timestamp_date = alltypes.timestamp_col.date() @@ -1727,9 +1738,6 @@ def test_timestamp_column_from_ymdhms(backend, con, alltypes, df): backend.assert_series_equal(golden, result.timestamp_col) -@pytest.mark.notimpl( - ["oracle"], raises=OracleDatabaseError, reason="ORA-01861 literal does not match" -) def test_date_scalar_from_iso(con): expr = ibis.literal("2022-02-24") expr2 = ibis.date(expr) @@ -1739,11 +1747,6 @@ def test_date_scalar_from_iso(con): @pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError) -@pytest.mark.notyet( - ["oracle"], - raises=OracleDatabaseError, - reason="ORA-22849 type CLOB is not supported", -) @pytest.mark.notimpl(["exasol"], raises=AssertionError, strict=False) def test_date_column_from_iso(backend, con, alltypes, df): expr = ( @@ -1830,7 +1833,6 @@ def build_date_col(t): @pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) -@pytest.mark.notimpl(["oracle"], raises=OracleDatabaseError) @pytest.mark.parametrize( ("left_fn", "right_fn"), [ @@ -2064,8 +2066,8 @@ def test_delta(con, start, end, unit, expected): ), pytest.mark.notimpl( ["oracle"], - raises=com.UnsupportedOperationError, - reason="backend doesn't support sub-second interval precision", + raises=com.OperationNotDefinedError, + reason="TimestampBucket not implemented", ), ], id="milliseconds",