diff --git a/ibis/backends/clickhouse/compiler/values.py b/ibis/backends/clickhouse/compiler/values.py index 68bcf9473b35..2faf121e310d 100644 --- a/ibis/backends/clickhouse/compiler/values.py +++ b/ibis/backends/clickhouse/compiler/values.py @@ -469,43 +469,38 @@ def _literal(op, **kw): elif micros // 1000: args.append(3) - if (timezone := op.output_dtype.timezone) is not None: + if (timezone := dtype.timezone) is not None: args.append(timezone) joined_args = ", ".join(map(repr, args)) return f"{func}({joined_args})" - elif isinstance(op.output_dtype, dt.Date): + elif dtype.is_date(): formatted = value.strftime('%Y-%m-%d') return f"toDate('{formatted}')" - elif isinstance(op.output_dtype, dt.Array): - values = ", ".join(_array_literal_values(op)) + elif dtype.is_array(): + value_type = dtype.value_type + values = ", ".join( + _literal(ops.Literal(v, dtype=value_type), **kw) for v in value + ) return f"[{values}]" - elif isinstance(op.output_dtype, dt.Map): - values = ", ".join(_map_literal_values(op)) + elif dtype.is_map(): + value_type = dtype.value_type + values = ", ".join( + f"{k!r}, {_literal(ops.Literal(v, dtype=value_type), **kw)}" + for k, v in value.items() + ) return f"map({values})" - elif isinstance(op.output_dtype, dt.Struct): - fields = ", ".join(f"{value} as `{key}`" for key, value in op.value.items()) + elif dtype.is_struct(): + fields = ", ".join( + _literal(ops.Literal(v, dtype=subdtype), **kw) + for subdtype, v in zip(dtype.types, value.values()) + ) return f"tuple({fields})" else: raise NotImplementedError(f'Unsupported type: {dtype!r}') -def _array_literal_values(op): - value_type = op.output_dtype.value_type - for v in op.value: - value = ops.Literal(v, dtype=value_type) - yield _literal(value) - - -def _map_literal_values(op): - value_type = op.output_dtype.value_type - for k, v in op.value.items(): - value = ops.Literal(v, dtype=value_type) - yield repr(k) - yield _literal(value) - - def _sql(obj, dialect="clickhouse"): try: return obj.sql(dialect=dialect) @@ -757,7 +752,7 @@ def _clip(op, **kw): def _struct_field(op, **kw): arg = op.arg arg_dtype = arg.output_dtype - arg = translate_val(op.arg, **kw) + arg = translate_val(op.arg, render_aliases=False, **kw) idx = arg_dtype.names.index(op.field) typ = arg_dtype.types[idx] return f"CAST({arg}.{idx + 1} AS {serialize(typ)})" @@ -1295,39 +1290,39 @@ def _rank(_, **kw): @translate_val.register(ops.ExtractProtocol) def _extract_protocol(op, **kw): - arg = translate_val(op.arg, **kw) + arg = translate_val(op.arg, render_aliases=False, **kw) return f"nullIf(protocol({arg}), '')" @translate_val.register(ops.ExtractAuthority) def _extract_authority(op, **kw): - arg = translate_val(op.arg, **kw) + arg = translate_val(op.arg, render_aliases=False, **kw) return f"nullIf(netloc({arg}), '')" @translate_val.register(ops.ExtractHost) def _extract_host(op, **kw): - arg = translate_val(op.arg, **kw) + arg = translate_val(op.arg, render_aliases=False, **kw) return f"nullIf(domain({arg}), '')" @translate_val.register(ops.ExtractFile) def _extract_file(op, **kw): - arg = translate_val(op.arg, **kw) + arg = translate_val(op.arg, render_aliases=False, **kw) return f"nullIf(cutFragment(pathFull({arg})), '')" @translate_val.register(ops.ExtractPath) def _extract_path(op, **kw): - arg = translate_val(op.arg, **kw) + arg = translate_val(op.arg, render_aliases=False, **kw) return f"nullIf(path({arg}), '')" @translate_val.register(ops.ExtractQuery) def _extract_query(op, **kw): - arg = translate_val(op.arg, **kw) + arg = translate_val(op.arg, render_aliases=False, **kw) if (key := op.key) is not None: - key = translate_val(key, **kw) + key = translate_val(key, render_aliases=False, **kw) return f"nullIf(extractURLParameter({arg}, {key}), '')" else: return f"nullIf(queryString({arg}), '')" @@ -1335,7 +1330,7 @@ def _extract_query(op, **kw): @translate_val.register(ops.ExtractFragment) def _extract_fragment(op, **kw): - arg = translate_val(op.arg, **kw) + arg = translate_val(op.arg, render_aliases=False, **kw) return f"nullIf(fragment({arg}), '')" diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index f099166f38a6..1bb4482ca1c0 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -78,9 +78,7 @@ def test_scalar_param_array(con): assert result == len(value) -@pytest.mark.notimpl( - ["clickhouse", "datafusion", "impala", "postgres", "pyspark", "druid", "oracle"] -) +@pytest.mark.notimpl(["datafusion", "impala", "postgres", "pyspark", "druid", "oracle"]) @pytest.mark.never( ["mysql", "sqlite", "mssql"], reason="mysql and sqlite will never implement struct types",