Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(decompile): make the decompiler run on TPCH query 1 #9779

Merged
merged 23 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a925189
feat(parse_sql): add support for sge.Cast and sge.DataType
gforsyth Jul 3, 2024
1d5b0de
feat(decompiler): qualify projections, sort keys, and groupbys
gforsyth Aug 6, 2024
229e0da
test(decompile): add decompile snapshot test for tpch1
gforsyth Aug 6, 2024
4519386
feat(decompile): add support for sge.Count
gforsyth Aug 6, 2024
c96bdd3
fix(decompile): wrap parens around all binary ops
gforsyth Aug 6, 2024
275f2a9
chore(decompile): rename snapshot file to avoid ruff
gforsyth Aug 6, 2024
d0173ff
feat(decompile): handle converting join filters
gforsyth Aug 6, 2024
b79c873
feat(datatypes): add mapping for typecode.DATETIME
gforsyth Aug 6, 2024
2d5d5f5
test(decompile): test tpch3 decom recom
gforsyth Aug 6, 2024
9253919
chore(snapshots): update snapshots w additional parens in binary ops
gforsyth Aug 7, 2024
46a2304
test(tpch): break out tpch results comparison into function
gforsyth Aug 7, 2024
83a8956
test(decompile_tpch): unify snapshot name and add ruff ignore
gforsyth Aug 7, 2024
242b196
fix(decompile): check that group keys exist, and merge multiple values
gforsyth Aug 7, 2024
538fb3f
fix(decompile): don't nuke join predicates if they already exist
gforsyth Aug 7, 2024
905401f
chore(decompile): update snapshots
gforsyth Aug 7, 2024
db6413d
chore(decompile): update backend sql snapshots
gforsyth Aug 7, 2024
89a55fd
fix(limit): include limits from sorts as well as scans
gforsyth Aug 8, 2024
741fe47
test(decompile): parametrize test and pull queries from sql files
gforsyth Aug 8, 2024
016dc25
test: move tpch decompile tests to separate file
gforsyth Aug 8, 2024
878b451
chore: move tests into duckdb backend tests
gforsyth Aug 8, 2024
2b0dad8
test(decompile): compare results of decompiled tpch to sql tpch
gforsyth Aug 8, 2024
df6ad4a
test(decompile): undo database switch after test
gforsyth Aug 8, 2024
2bbbd06
fix(decompile): use `_` for more robust sort key handling
gforsyth Aug 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
typecode.BOOLEAN: dt.Boolean,
typecode.CHAR: dt.String,
typecode.DATE: dt.Date,
typecode.DATETIME: dt.Timestamp,
typecode.DATE32: dt.Date,
typecode.DOUBLE: dt.Float64,
typecode.ENUM: dt.String,
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@
operator = _infix_ops[type(op)]
left = _maybe_add_parens(op.left, left)
right = _maybe_add_parens(op.right, right)
return f"{left} {operator} {right}"
return _maybe_add_parens(op, f"{left} {operator} {right}")

Check warning on line 345 in ibis/expr/decompile.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/decompile.py#L345

Added line #L345 was not covered by tests


@translate.register(ops.InValues)
Expand Down
126 changes: 121 additions & 5 deletions ibis/expr/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.sql.datatypes import SqlglotType
from ibis.util import experimental


Expand Down Expand Up @@ -95,18 +96,87 @@
return table


def qualify_sort_keys(keys, table_name):
    """Attach a table qualifier to every unqualified column in ``keys``.

    The sqlglot planner doesn't fully qualify sort keys, e.g.

        - Sort: lineitem (132849388268768)
          Context:
            Key:
              - "l_returnflag"
              - "l_linestatus"

    For now this does the naive thing and qualifies each bare column with
    the name of the sort operation itself, which (maybe?) is the name of the
    parent table.

    Parameters
    ----------
    keys
        Iterable of sqlglot expressions used as sort keys.
    table_name
        Name used as the (quoted) table qualifier.

    Returns
    -------
    list
        One transformed expression per input key, in the same order.
    """
    qualifier = sg.to_identifier(table_name, quoted=True)

    def _attach_table(expr):
        # Only columns that carry no table yet are touched; everything else
        # is passed through unchanged.
        is_bare_column = isinstance(expr, sge.Column) and not expr.table
        if is_bare_column:
            expr.args["table"] = qualifier
        return expr

    return [sort_key.transform(_attach_table) for sort_key in keys]

Check warning on line 119 in ibis/expr/sql.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/sql.py#L119

Added line #L119 was not covered by tests


def qualify_projections(projections, groups):
    """Replace planner-generated group aliases in ``projections``.

    The sqlglot planner will (sometimes) alias projections to the aggregate
    that precedes it:

        - Sort: lineitem (132849388268768)
          Context:
            Key:
              - "l_returnflag"
              - "l_linestatus"
          Projections:
            - lineitem._g0 AS "l_returnflag"
            - lineitem._g1 AS "l_linestatus"
            <snip>
          Dependencies:
            - Aggregate: lineitem (132849388268864)
              Context:
                Aggregations:
                  <snip>
                Group:
                  - "lineitem"."l_returnflag"  <-- this is _g0
                  - "lineitem"."l_linestatus"  <-- this is _g1
                  <snip>

    These aliases are stored in a dictionary in the aggregate ``group``, so
    if those are pulled out beforehand they can be used to replace the
    aliases in the projections.

    Parameters
    ----------
    projections
        Iterable of sqlglot projection expressions.
    groups
        Sequence of ``{alias: expression}`` mappings, one per dependency that
        defines groups. May be empty.

    Returns
    -------
    list
        One transformed projection per input projection, in the same order.
    """
    # Merge every mapping into a single lookup instead of consulting only
    # ``groups[0]``. Walking in reverse makes the earliest mapping win on a
    # name clash, which keeps the result identical to the old behaviour
    # whenever the old behaviour succeeded.
    aliases = {}
    for group in reversed(list(groups)):
        if group:
            aliases.update(group)

    def transformer(node):
        # Only swap aliases that are actually known group keys; an unrelated
        # name that merely starts with "_g" is left untouched rather than
        # raising KeyError (or IndexError when ``groups`` is empty).
        if (
            isinstance(node, sge.Alias)
            and (name := node.this.name).startswith("_g")
            and name in aliases
        ):
            return aliases[name]
        return node

    return [projection.transform(transformer) for projection in projections]

Check warning on line 156 in ibis/expr/sql.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/sql.py#L156

Added line #L156 was not covered by tests


@convert.register(sgp.Sort)
def convert_sort(sort, catalog):
    """Convert a sqlglot planner ``Sort`` step into an ibis table expression.

    Sort keys are table-qualified and group aliases in the projections are
    resolved before each expression is converted, so every expression is
    converted exactly once and only in its qualified form.

    Parameters
    ----------
    sort
        The planner ``Sort`` step.
    catalog
        Catalog used to resolve table names; it is overlaid with ``sort``.

    Returns
    -------
    The table registered under ``sort.name``, ordered by the sort keys (if
    any) and then projected (if the step carries projections).
    """
    catalog = catalog.overlay(sort)

    table = catalog[sort.name]

    if sort.key:
        keys = [
            convert(key, catalog=catalog)
            for key in qualify_sort_keys(sort.key, sort.name)
        ]
        table = table.order_by(keys)

    if sort.projections:
        # group definitions that may be used in projections are defined
        # in the aggregate in dependencies... Dependencies that carry no
        # ``group`` attribute are skipped instead of raising AttributeError.
        groups = [dep.group for dep in sort.dependencies if hasattr(dep, "group")]
        projs = [
            convert(proj, catalog=catalog)
            for proj in qualify_projections(sort.projections, groups)
        ]
        table = table.select(projs)

    return table
Expand Down Expand Up @@ -139,7 +209,8 @@
predicate = left_key == right_key
else:
predicate &= left_key == right_key
else:

if "condition" in desc.keys():
condition = desc["condition"]
predicate = convert(condition, catalog=catalog)

Expand All @@ -154,10 +225,38 @@
return left_table


def replace_operands(agg):
    """Inline aliased operands back into the aggregations of ``agg``.

    The sqlglot planner will pull out computed operands into a separate
    section and alias them, e.g.

        Context:
          Aggregations:
            - SUM("_a_0") AS "sum_disc_price"
          Operands:
            - "lineitem"."l_extendedprice" * (1 - "lineitem"."l_discount") AS _a_0

    For the purposes of decompiling, we want these to be inline, so the
    aliased columns are replaced with the parsed sqlglot expression they
    stand for.

    Parameters
    ----------
    agg
        Step exposing ``operands`` and ``aggregations``.

    Returns
    -------
    The same ``agg`` object, with ``aggregations`` reassigned to the
    rewritten expressions.
    """
    inlined = {operand.alias: operand.this for operand in agg.operands}

    def _inline(expr):
        # Columns named after an operand alias are swapped for the operand's
        # expression; any other node is returned as-is.
        if isinstance(expr, sge.Column):
            return inlined.get(expr.name, expr)
        return expr

    agg.aggregations = [item.transform(_inline) for item in agg.aggregations]
    return agg

Check warning on line 251 in ibis/expr/sql.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/sql.py#L251

Added line #L251 was not covered by tests


@convert.register(sgp.Aggregate)
def convert_aggregate(agg, catalog):
catalog = catalog.overlay(agg)

agg = replace_operands(agg)

Check warning on line 258 in ibis/expr/sql.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/sql.py#L258

Added line #L258 was not covered by tests

table = catalog[agg.source]
if agg.aggregations:
metrics = [convert(a, catalog=catalog) for a in agg.aggregations]
Expand Down Expand Up @@ -204,7 +303,7 @@
@convert.register(sge.Ordered)
def convert_ordered(ordered, catalog):
    """Convert a sqlglot ``Ordered`` node into an ibis sort key.

    Parameters
    ----------
    ordered
        The sqlglot ``Ordered`` expression.
    catalog
        Catalog forwarded to ``convert`` for the ordered expression.

    Returns
    -------
    ``ibis.desc(...)`` when the node is flagged descending, otherwise
    ``ibis.asc(...)``.
    """
    this = convert(ordered.this, catalog=catalog)
    # "desc" is not exposed as an attribute and may be missing from ``args``
    # altogether, so look it up with a default (ascending) instead of
    # indexing, which would raise KeyError.
    desc = ordered.args.get("desc", False)
    return ibis.desc(this) if desc else ibis.asc(this)


Expand Down Expand Up @@ -258,7 +357,6 @@
sge.Quantile: "quantile",
sge.Sum: "sum",
sge.Avg: "mean",
sge.Count: "count",
}


Expand All @@ -276,6 +374,24 @@
return this.isin(candidates)


@convert.register(sge.Cast)
def cast(cast, catalog):
    """Convert a sqlglot ``Cast`` node by casting its converted operand.

    Both the operand (``cast.this``) and the target type (``cast.to``) are
    run through ``convert`` first; the converted operand's ``cast`` method is
    then called with the converted target.
    """
    operand = convert(cast.this, catalog)
    target_type = convert(cast.to, catalog)
    return operand.cast(target_type)

Check warning on line 382 in ibis/expr/sql.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/sql.py#L382

Added line #L382 was not covered by tests


@convert.register(sge.DataType)
def datatype(datatype, catalog):
    """Translate a sqlglot ``DataType`` node via ``SqlglotType.to_ibis``.

    ``catalog`` is accepted for signature compatibility with the other
    converters and is not used.
    """
    type_mapper = SqlglotType()
    return type_mapper.to_ibis(datatype)

Check warning on line 387 in ibis/expr/sql.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/sql.py#L387

Added line #L387 was not covered by tests


@convert.register(sge.Count)
def count(count, catalog):
    """Convert a sqlglot ``Count`` node to ``ibis._.count()``.

    Neither the node nor ``catalog`` is inspected: the result is the same
    regardless of what the COUNT was applied to.
    """
    counted = ibis._.count()
    return counted

Check warning on line 392 in ibis/expr/sql.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/sql.py#L392

Added line #L392 was not covered by tests
gforsyth marked this conversation as resolved.
Show resolved Hide resolved


@public
@experimental
def parse_sql(sqlstring, catalog, dialect=None):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import ibis

# Unbound ``lineitem`` table with its column types.
lineitem = ibis.table(
    name="lineitem",
    schema={
        "l_orderkey": "int32",
        "l_partkey": "int32",
        "l_suppkey": "int32",
        "l_linenumber": "int32",
        "l_quantity": "decimal(15, 2)",
        "l_extendedprice": "decimal(15, 2)",
        "l_discount": "decimal(15, 2)",
        "l_tax": "decimal(15, 2)",
        "l_returnflag": "string",
        "l_linestatus": "string",
        "l_shipdate": "date",
        "l_commitdate": "date",
        "l_receiptdate": "date",
        "l_shipinstruct": "string",
        "l_shipmode": "string",
        "l_comment": "string",
    },
)
# Shared literal 1 used in the price arithmetic below.
lit = ibis.literal(1)
# Keep rows shipped on or before 1998-09-02.
f = lineitem.filter(lineitem.l_shipdate <= ibis.literal("1998-09-02").cast("date"))
# Discounted price: extended price * (1 - discount).
multiply = f.l_extendedprice * (lit - f.l_discount)
# Sums, means and a row count, grouped by return flag and line status.
agg = f.aggregate(
    [
        f.l_quantity.sum().name("sum_qty"),
        f.l_extendedprice.sum().name("sum_base_price"),
        multiply.sum().name("sum_disc_price"),
        ((multiply) * (lit + f.l_tax)).sum().name("sum_charge"),
        f.l_quantity.mean().name("avg_qty"),
        f.l_extendedprice.mean().name("avg_price"),
        f.l_discount.mean().name("avg_disc"),
        f.count().name("count_order"),
    ],
    by=[f.l_returnflag, f.l_linestatus],
)

# Final expression: the aggregate ordered ascending by both group keys.
result = agg.order_by(agg.l_returnflag.asc(), agg.l_linestatus.asc())
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import ibis


# Unbound ``lineitem`` table with its column types.
lineitem = ibis.table(
    name="lineitem",
    schema={
        "l_orderkey": "int32",
        "l_partkey": "int32",
        "l_suppkey": "int32",
        "l_linenumber": "int32",
        "l_quantity": "decimal(15, 2)",
        "l_extendedprice": "decimal(15, 2)",
        "l_discount": "decimal(15, 2)",
        "l_tax": "decimal(15, 2)",
        "l_returnflag": "string",
        "l_linestatus": "string",
        "l_shipdate": "date",
        "l_commitdate": "date",
        "l_receiptdate": "date",
        "l_shipinstruct": "string",
        "l_shipmode": "string",
        "l_comment": "string",
    },
)
# Shared literal 1 used in the price arithmetic below.
lit = ibis.literal(1)
# Keep rows shipped on or before 1998-09-02 (binary op wrapped in parens).
f = lineitem.filter((lineitem.l_shipdate <= ibis.literal("1998-09-02").cast("date")))
# Discounted price: extended price * (1 - discount).
multiply = f.l_extendedprice * ((lit - f.l_discount))
# Sums, means and a row count, grouped by return flag and line status.
agg = f.aggregate(
    [
        f.l_quantity.sum().name("sum_qty"),
        f.l_extendedprice.sum().name("sum_base_price"),
        multiply.sum().name("sum_disc_price"),
        ((multiply) * ((lit + f.l_tax))).sum().name("sum_charge"),
        f.l_quantity.mean().name("avg_qty"),
        f.l_extendedprice.mean().name("avg_price"),
        f.l_discount.mean().name("avg_disc"),
        f.count().name("count_order"),
    ],
    by=[f.l_returnflag, f.l_linestatus],
)

# Final expression: the aggregate ordered ascending by both group keys.
result = agg.order_by(agg.l_returnflag.asc(), agg.l_linestatus.asc())
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import ibis


# Unbound ``customer`` table with its column types.
customer = ibis.table(
    name="customer",
    schema={
        "c_custkey": "int64",
        "c_name": "string",
        "c_address": "string",
        "c_nationkey": "int16",
        "c_phone": "string",
        "c_acctbal": "decimal",
        "c_mktsegment": "string",
        "c_comment": "string",
    },
)
# Literal True, used as the first predicate of both joins below.
lit = ibis.literal(True)
# Unbound ``orders`` table; note that ``o_orderdate`` is declared as a string.
orders = ibis.table(
    name="orders",
    schema={
        "o_orderkey": "int64",
        "o_custkey": "int64",
        "o_orderstatus": "string",
        "o_totalprice": "decimal(12, 2)",
        "o_orderdate": "string",
        "o_orderpriority": "string",
        "o_clerk": "string",
        "o_shippriority": "int32",
        "o_comment": "string",
    },
)
# Unbound ``lineitem`` table with its column types.
lineitem = ibis.table(
    name="lineitem",
    schema={
        "l_orderkey": "int32",
        "l_partkey": "int32",
        "l_suppkey": "int32",
        "l_linenumber": "int32",
        "l_quantity": "decimal(15, 2)",
        "l_extendedprice": "decimal(15, 2)",
        "l_discount": "decimal(15, 2)",
        "l_tax": "decimal(15, 2)",
        "l_returnflag": "string",
        "l_linestatus": "string",
        "l_shipdate": "date",
        "l_commitdate": "date",
        "l_receiptdate": "date",
        "l_shipinstruct": "string",
        "l_shipmode": "string",
        "l_comment": "string",
    },
)
# Cut-off date shared by both join predicates.
cast = ibis.literal("1995-03-15").cast("date")
# customer -> orders -> lineitem inner joins, then every column of all three
# tables is selected. Each join's predicates are the literal True plus a date
# comparison against the cut-off.
# NOTE(review): no key-equality predicates (e.g. c_custkey == o_custkey)
# appear in either join -- confirm whether they were dropped upstream.
joinchain = (
    customer.inner_join(orders, [lit, (orders.o_orderdate.cast("timestamp") < cast)])
    .inner_join(lineitem, [lit, (lineitem.l_shipdate > cast)])
    .select(
        customer.c_custkey,
        customer.c_name,
        customer.c_address,
        customer.c_nationkey,
        customer.c_phone,
        customer.c_acctbal,
        customer.c_mktsegment,
        customer.c_comment,
        orders.o_orderkey,
        orders.o_custkey,
        orders.o_orderstatus,
        orders.o_totalprice,
        orders.o_orderdate,
        orders.o_orderpriority,
        orders.o_clerk,
        orders.o_shippriority,
        orders.o_comment,
        lineitem.l_orderkey,
        lineitem.l_partkey,
        lineitem.l_suppkey,
        lineitem.l_linenumber,
        lineitem.l_quantity,
        lineitem.l_extendedprice,
        lineitem.l_discount,
        lineitem.l_tax,
        lineitem.l_returnflag,
        lineitem.l_linestatus,
        lineitem.l_shipdate,
        lineitem.l_commitdate,
        lineitem.l_receiptdate,
        lineitem.l_shipinstruct,
        lineitem.l_shipmode,
        lineitem.l_comment,
    )
)
# Keep only customers in the "BUILDING" market segment.
f = joinchain.filter((joinchain.c_mktsegment == "BUILDING"))
# Revenue = sum(extended price * (1 - discount)), grouped by order key, order
# date and ship priority.
agg = f.aggregate(
    [(f.l_extendedprice * ((1 - f.l_discount))).sum().name("revenue")],
    by=[f.l_orderkey, f.o_orderdate, f.o_shippriority],
)
# Highest revenue first, ties broken by ascending order date.
s = agg.order_by(agg.revenue.desc(), agg.o_orderdate.asc())

# Final expression: the ordered aggregate with its columns in output order.
result = s.select(s.l_orderkey, s.revenue, s.o_orderdate, s.o_shippriority)
Loading
Loading