Skip to content

Commit

Permalink
fix(clickhouse): avoid generating names for structs
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed May 25, 2023
1 parent 9c40c06 commit 5d11f48
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 35 deletions.
59 changes: 27 additions & 32 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,43 +469,38 @@ def _literal(op, **kw):
elif micros // 1000:
args.append(3)

if (timezone := op.output_dtype.timezone) is not None:
if (timezone := dtype.timezone) is not None:
args.append(timezone)

joined_args = ", ".join(map(repr, args))
return f"{func}({joined_args})"

elif isinstance(op.output_dtype, dt.Date):
elif dtype.is_date():
formatted = value.strftime('%Y-%m-%d')
return f"toDate('{formatted}')"
elif isinstance(op.output_dtype, dt.Array):
values = ", ".join(_array_literal_values(op))
elif dtype.is_array():
value_type = dtype.value_type
values = ", ".join(
_literal(ops.Literal(v, dtype=value_type), **kw) for v in value
)
return f"[{values}]"
elif isinstance(op.output_dtype, dt.Map):
values = ", ".join(_map_literal_values(op))
elif dtype.is_map():
value_type = dtype.value_type
values = ", ".join(
f"{k!r}, {_literal(ops.Literal(v, dtype=value_type), **kw)}"
for k, v in value.items()
)
return f"map({values})"
elif isinstance(op.output_dtype, dt.Struct):
fields = ", ".join(f"{value} as `{key}`" for key, value in op.value.items())
elif dtype.is_struct():
fields = ", ".join(
_literal(ops.Literal(v, dtype=subdtype), **kw)
for subdtype, v in zip(dtype.types, value.values())
)
return f"tuple({fields})"
else:
raise NotImplementedError(f'Unsupported type: {dtype!r}')


def _array_literal_values(op):
value_type = op.output_dtype.value_type
for v in op.value:
value = ops.Literal(v, dtype=value_type)
yield _literal(value)


def _map_literal_values(op):
value_type = op.output_dtype.value_type
for k, v in op.value.items():
value = ops.Literal(v, dtype=value_type)
yield repr(k)
yield _literal(value)


def _sql(obj, dialect="clickhouse"):
try:
return obj.sql(dialect=dialect)
Expand Down Expand Up @@ -757,7 +752,7 @@ def _clip(op, **kw):
def _struct_field(op, **kw):
arg = op.arg
arg_dtype = arg.output_dtype
arg = translate_val(op.arg, **kw)
arg = translate_val(op.arg, render_aliases=False, **kw)
idx = arg_dtype.names.index(op.field)
typ = arg_dtype.types[idx]
return f"CAST({arg}.{idx + 1} AS {serialize(typ)})"
Expand Down Expand Up @@ -1295,47 +1290,47 @@ def _rank(_, **kw):

@translate_val.register(ops.ExtractProtocol)
def _extract_protocol(op, **kw):
arg = translate_val(op.arg, **kw)
arg = translate_val(op.arg, render_aliases=False, **kw)
return f"nullIf(protocol({arg}), '')"


@translate_val.register(ops.ExtractAuthority)
def _extract_authority(op, **kw):
arg = translate_val(op.arg, **kw)
arg = translate_val(op.arg, render_aliases=False, **kw)
return f"nullIf(netloc({arg}), '')"


@translate_val.register(ops.ExtractHost)
def _extract_host(op, **kw):
arg = translate_val(op.arg, **kw)
arg = translate_val(op.arg, render_aliases=False, **kw)
return f"nullIf(domain({arg}), '')"


@translate_val.register(ops.ExtractFile)
def _extract_file(op, **kw):
arg = translate_val(op.arg, **kw)
arg = translate_val(op.arg, render_aliases=False, **kw)
return f"nullIf(cutFragment(pathFull({arg})), '')"


@translate_val.register(ops.ExtractPath)
def _extract_path(op, **kw):
arg = translate_val(op.arg, **kw)
arg = translate_val(op.arg, render_aliases=False, **kw)
return f"nullIf(path({arg}), '')"


@translate_val.register(ops.ExtractQuery)
def _extract_query(op, **kw):
arg = translate_val(op.arg, **kw)
arg = translate_val(op.arg, render_aliases=False, **kw)
if (key := op.key) is not None:
key = translate_val(key, **kw)
key = translate_val(key, render_aliases=False, **kw)
return f"nullIf(extractURLParameter({arg}, {key}), '')"
else:
return f"nullIf(queryString({arg}), '')"


@translate_val.register(ops.ExtractFragment)
def _extract_fragment(op, **kw):
arg = translate_val(op.arg, **kw)
arg = translate_val(op.arg, render_aliases=False, **kw)
return f"nullIf(fragment({arg}), '')"


Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def test_scalar_param_array(con):
assert result == len(value)


@pytest.mark.notimpl(
["clickhouse", "datafusion", "impala", "postgres", "pyspark", "druid", "oracle"]
)
@pytest.mark.notimpl(["datafusion", "impala", "postgres", "pyspark", "druid", "oracle"])
@pytest.mark.never(
["mysql", "sqlite", "mssql"],
reason="mysql and sqlite will never implement struct types",
Expand Down

0 comments on commit 5d11f48

Please sign in to comment.