diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py index a143f511affc..505883f04462 100644 --- a/ibis/backends/base/sql/alchemy/registry.py +++ b/ibis/backends/base/sql/alchemy/registry.py @@ -163,18 +163,18 @@ def _exists_subquery(t, op): def _cast(t, op): arg = op.arg typ = op.to + arg_dtype = arg.output_dtype sa_arg = t.translate(arg) - sa_type = t.get_sqla_type(typ) - if isinstance(arg, ir.CategoryValue) and typ == dt.int32: + if arg_dtype.is_category() and typ.is_int32(): return sa_arg # specialize going from an integer type to a timestamp - if arg.output_dtype.is_integer() and isinstance(sa_type, sa.DateTime): + if arg_dtype.is_integer() and typ.is_timestamp(): return t.integer_to_timestamp(sa_arg) - if arg.output_dtype.is_binary() and typ.is_string(): + if arg_dtype.is_binary() and typ.is_string(): return sa.func.encode(sa_arg, 'escape') if typ.is_binary(): @@ -185,10 +185,7 @@ def _cast(t, op): if typ.is_json() and not t.native_json_type: return sa_arg - ignore_cast_types = t._ignore_cast_types - if ignore_cast_types and isinstance(typ, ignore_cast_types): - return sa_arg - return sa.cast(sa_arg, sa_type) + return t.cast(sa_arg, typ) def _contains(func): diff --git a/ibis/backends/base/sql/alchemy/translator.py b/ibis/backends/base/sql/alchemy/translator.py index 62f22803bb8f..0fee7cde6126 100644 --- a/ibis/backends/base/sql/alchemy/translator.py +++ b/ibis/backends/base/sql/alchemy/translator.py @@ -45,7 +45,6 @@ class AlchemyExprTranslator(ExprTranslator): integer_to_timestamp = sa.func.to_timestamp native_json_type = True _always_quote_columns = False - _ignore_cast_types = () _require_order_by = ( ops.DenseRank, @@ -89,6 +88,9 @@ def _reduction(self, sa_func, op): return sa_func(*sa_args) + def cast(self, sa_expr, ibis_type: dt.DataType): + return sa.cast(sa_expr, self.get_sqla_type(ibis_type)) + rewrites = AlchemyExprTranslator.rewrites diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 454505af0917..d6acf2d4fde3 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -40,7 +40,11 @@ class SnowflakeExprTranslator(AlchemyExprTranslator): ops.Lead, ) _require_order_by = (*AlchemyExprTranslator._require_order_by, ops.Reduction) - _ignore_cast_types = (dt.Map, dt.Array) + + def cast(self, sa_expr, ibis_type: dt.DataType): + if ibis_type.is_array() or ibis_type.is_map() or ibis_type.is_struct(): + return sa.type_coerce(sa_expr, self.get_sqla_type(ibis_type)) + return super().cast(sa_expr, ibis_type) class SnowflakeCompiler(AlchemyCompiler):