Skip to content

Commit

Permalink
feat(decompile): make the decompiler run on TPCH query 1 (#9779)
Browse files Browse the repository at this point in the history
Context: the sql -> ibis expression parser/creator runs off of a
`sqlglot.Plan` object

The `sqlglot` plan optimizer replaces a few items with aliases. In other
places, it doesn't use
fully-qualified column names. Both of these prevent us from reliably
converting into an Ibis expression.

## Description of changes

### Dereferencing aggregation operands

The sqlglot planner will pull out computed operands into a separate
section and alias them e.g.
```
Context:
    Aggregations:
    - SUM("_a_0") AS "sum_disc_price"
    Operands:
    - "lineitem"."l_extendedprice" * (1 - "lineitem"."l_discount") AS _a_0
```

For the purposes of decompiling, we want these operands inlined, so here we
replace the generated aliases with their parsed sqlglot expressions.

### Dereferencing projections

The sqlglot planner will (sometimes) alias projections to the aggregate
that precedes it.

```
- Sort: lineitem (132849388268768)
  Context:
    Key:
      - "l_returnflag"
      - "l_linestatus"
  Projections:
    - lineitem._g0 AS "l_returnflag"
    - lineitem._g1 AS "l_linestatus"
    <snip>
  Dependencies:
  - Aggregate: lineitem (132849388268864)
    Context:
      Aggregations:
        <snip>
      Group:
        - "lineitem"."l_returnflag"  <-- this is _g0
        - "lineitem"."l_linestatus"  <-- this is _g1
        <snip>
```

These aliases are stored in a dictionary on the aggregate's `groups`, so if
we extract that mapping beforehand we can use it to replace the
corresponding aliases in the projections.

### Dereferencing sort keys

The sqlglot planner doesn't fully qualify sort keys

```
- Sort: lineitem (132849388268768)
  Context:
    Key:
      - "l_returnflag"
      - "l_linestatus"
```

~For now we do a naive thing here and prepend the name of the sort~
~operation itself, which (maybe?) is the name of the parent table.~
~*This is definitely the most brittle part of the existing decompiler*~
 
I've mucked around with this a little bit, and while it's a _touch_
hacky, using the deferred operator with the not-fully-referenced
sort-keys works well.
  • Loading branch information
gforsyth authored Aug 26, 2024
1 parent ea97794 commit 0268044
Show file tree
Hide file tree
Showing 49 changed files with 494 additions and 102 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import ibis


# TPC-H ``lineitem`` table definition (decompiled output for query 1).
lineitem = ibis.table(
    name="lineitem",
    schema={
        "l_orderkey": "int32",
        "l_partkey": "int32",
        "l_suppkey": "int32",
        "l_linenumber": "int32",
        "l_quantity": "decimal(15, 2)",
        "l_extendedprice": "decimal(15, 2)",
        "l_discount": "decimal(15, 2)",
        "l_tax": "decimal(15, 2)",
        "l_returnflag": "string",
        "l_linestatus": "string",
        "l_shipdate": "date",
        "l_commitdate": "date",
        "l_receiptdate": "date",
        "l_shipinstruct": "string",
        "l_shipmode": "string",
        "l_comment": "string",
    },
)
lit = ibis.literal(1)
# Keep only rows shipped on or before the cutoff date.
f = lineitem.filter(lineitem.l_shipdate <= ibis.literal("1998-09-02").cast("date"))
# Discounted extended price: l_extendedprice * (1 - l_discount).
multiply = f.l_extendedprice * (lit - f.l_discount)
agg = f.aggregate(
    [
        f.l_quantity.sum().name("sum_qty"),
        f.l_extendedprice.sum().name("sum_base_price"),
        multiply.sum().name("sum_disc_price"),
        (multiply * (lit + f.l_tax)).sum().name("sum_charge"),
        f.l_quantity.mean().name("avg_qty"),
        f.l_extendedprice.mean().name("avg_price"),
        f.l_discount.mean().name("avg_disc"),
        f.count().name("count_order"),
    ],
    by=[f.l_returnflag, f.l_linestatus],
)

# Final ordering: by return flag then line status, ascending, nulls first.
result = agg.order_by(
    agg.l_returnflag.asc(nulls_first=True), agg.l_linestatus.asc(nulls_first=True)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import ibis


# TPC-H table definitions (decompiled output for query 3).
customer = ibis.table(
    name="customer",
    schema={
        "c_custkey": "int64",
        "c_name": "string",
        "c_address": "string",
        "c_nationkey": "int16",
        "c_phone": "string",
        "c_acctbal": "decimal",
        "c_mktsegment": "string",
        "c_comment": "string",
    },
)
lit = ibis.literal(True)
orders = ibis.table(
    name="orders",
    schema={
        "o_orderkey": "int64",
        "o_custkey": "int64",
        "o_orderstatus": "string",
        "o_totalprice": "decimal(12, 2)",
        "o_orderdate": "date",
        "o_orderpriority": "string",
        "o_clerk": "string",
        "o_shippriority": "int32",
        "o_comment": "string",
    },
)
lineitem = ibis.table(
    name="lineitem",
    schema={
        "l_orderkey": "int32",
        "l_partkey": "int32",
        "l_suppkey": "int32",
        "l_linenumber": "int32",
        "l_quantity": "decimal(15, 2)",
        "l_extendedprice": "decimal(15, 2)",
        "l_discount": "decimal(15, 2)",
        "l_tax": "decimal(15, 2)",
        "l_returnflag": "string",
        "l_linestatus": "string",
        "l_shipdate": "date",
        "l_commitdate": "date",
        "l_receiptdate": "date",
        "l_shipinstruct": "string",
        "l_shipmode": "string",
        "l_comment": "string",
    },
)
cast = ibis.literal("1995-03-15").cast("date")
# Three-way join customer -> orders -> lineitem; the date predicates are
# carried in the join conditions, then every column is projected through.
joinchain = (
    customer.inner_join(
        orders,
        [customer.c_custkey == orders.o_custkey, lit, orders.o_orderdate < cast],
    )
    .inner_join(
        lineitem,
        [orders.o_orderkey == lineitem.l_orderkey, lit, lineitem.l_shipdate > cast],
    )
    .select(
        customer.c_custkey,
        customer.c_name,
        customer.c_address,
        customer.c_nationkey,
        customer.c_phone,
        customer.c_acctbal,
        customer.c_mktsegment,
        customer.c_comment,
        orders.o_orderkey,
        orders.o_custkey,
        orders.o_orderstatus,
        orders.o_totalprice,
        orders.o_orderdate,
        orders.o_orderpriority,
        orders.o_clerk,
        orders.o_shippriority,
        orders.o_comment,
        lineitem.l_orderkey,
        lineitem.l_partkey,
        lineitem.l_suppkey,
        lineitem.l_linenumber,
        lineitem.l_quantity,
        lineitem.l_extendedprice,
        lineitem.l_discount,
        lineitem.l_tax,
        lineitem.l_returnflag,
        lineitem.l_linestatus,
        lineitem.l_shipdate,
        lineitem.l_commitdate,
        lineitem.l_receiptdate,
        lineitem.l_shipinstruct,
        lineitem.l_shipmode,
        lineitem.l_comment,
    )
)
f = joinchain.filter(joinchain.c_mktsegment == "BUILDING")
# revenue = sum(l_extendedprice * (1 - l_discount)) grouped by order key,
# order date, and ship priority.
agg = f.aggregate(
    [(f.l_extendedprice * (1 - f.l_discount)).sum().name("revenue")],
    by=[f.l_orderkey, f.o_orderdate, f.o_shippriority],
)
s = agg.order_by(agg.revenue.desc(), agg.o_orderdate.asc(nulls_first=True))

result = s.select(s.l_orderkey, s.revenue, s.o_orderdate, s.o_shippriority).limit(10)
101 changes: 101 additions & 0 deletions ibis/backends/duckdb/tests/test_decompile_tpch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import annotations

import importlib
from contextlib import contextmanager
from pathlib import Path

import pytest
from pytest import param

import ibis
from ibis.backends.tests.tpc.conftest import compare_tpc_results
from ibis.formats.pandas import PandasData

# Catalog of TPC-H table schemas handed to ``ibis.parse_sql``.
# NOTE(review): ``lineitem`` uses a name->type dict while ``customer`` and
# ``orders`` use (name, type) tuple lists — presumably deliberate, to exercise
# both accepted catalog formats; confirm before normalizing.
tpch_catalog = {
    "lineitem": {
        "l_orderkey": "int32",
        "l_partkey": "int32",
        "l_suppkey": "int32",
        "l_linenumber": "int32",
        "l_quantity": "decimal(15, 2)",
        "l_extendedprice": "decimal(15, 2)",
        "l_discount": "decimal(15, 2)",
        "l_tax": "decimal(15, 2)",
        "l_returnflag": "string",
        "l_linestatus": "string",
        "l_shipdate": "date",
        "l_commitdate": "date",
        "l_receiptdate": "date",
        "l_shipinstruct": "string",
        "l_shipmode": "string",
        "l_comment": "string",
    },
    "customer": [
        ("c_custkey", "int64"),
        ("c_name", "string"),
        ("c_address", "string"),
        ("c_nationkey", "int16"),
        ("c_phone", "string"),
        ("c_acctbal", "decimal"),
        ("c_mktsegment", "string"),
        ("c_comment", "string"),
    ],
    "orders": [
        ("o_orderkey", "int64"),
        ("o_custkey", "int64"),
        ("o_orderstatus", "string"),
        ("o_totalprice", "decimal(12,2)"),
        ("o_orderdate", "date"),
        ("o_orderpriority", "string"),
        ("o_clerk", "string"),
        ("o_shippriority", "int32"),
        ("o_comment", "string"),
    ],
}

# Repository-relative location of the DuckDB TPC-H query files.
root = Path(__file__).absolute().parents[3]

SQL_QUERY_PATH = root.joinpath("backends", "tests", "tpc", "queries", "duckdb", "h")


@contextmanager
def set_database(con, db):
    """Temporarily switch ``con`` to database ``db``.

    The previously-current database is restored on exit — including when the
    ``with`` body raises, which the original version did not guarantee.
    """
    olddb = con.current_database
    con.raw_sql(f"USE {db}")
    try:
        yield
    finally:
        # Always restore, even on error, so later tests see the original db.
        con.raw_sql(f"USE {olddb}")


@pytest.mark.parametrize(
    "tpch_query",
    [
        param(1, id="tpch01"),
        param(3, id="tpch03"),
    ],
)
def test_parse_sql_tpch(tpch_query, snapshot, con, data_dir):
    """Round-trip a TPC-H query: SQL -> ibis expression -> decompiled code.

    The decompiled code is snapshotted, imported back as a module, executed
    against DuckDB, and its result compared with running the raw SQL directly.

    NOTE(review): ``data_dir`` appears unused here; presumably it is kept to
    trigger a data-loading fixture — confirm before removing.
    """
    sql = (SQL_QUERY_PATH / f"{tpch_query:02d}.sql").read_text()

    expr = ibis.parse_sql(sql, tpch_catalog)
    code = ibis.decompile(expr, format=True)
    snapshot.assert_match(code, "out_tpch.py")

    # Import the just-created snapshot so the decompiled query can be executed.
    snapshot_module = (
        "ibis.backends.duckdb.tests.snapshots.test_decompile_tpch."
        f"test_parse_sql_tpch.tpch{tpch_query:02d}.out_tpch"
    )
    module = importlib.import_module(snapshot_module)

    with set_database(con, "tpch"):
        # Results from executing the SQL directly on DuckDB...
        expected_df = con.con.execute(sql).df()
        # ...and from executing the decompiled ibis query.
        result_df = con.to_pandas(module.result)

    # Align the expected columns so the pandas dtypes can be coerced correctly.
    expected_df.columns = result_df.columns

    expected_df = PandasData.convert_table(expected_df, module.result.schema())

    compare_tpc_results(result_df, expected_df)
1 change: 1 addition & 0 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
typecode.BOOLEAN: dt.Boolean,
typecode.CHAR: dt.String,
typecode.DATE: dt.Date,
typecode.DATETIME: dt.Timestamp,
typecode.DATE32: dt.Date,
typecode.DOUBLE: dt.Float64,
typecode.ENUM: dt.String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

int_col_table = ibis.table(name="int_col_table", schema={"int_col": "int32"})

result = (int_col_table.int_col + 4).name("foo")
result = ((int_col_table.int_col + 4)).name("foo")
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
"month": "int32",
},
)
f = functional_alltypes.filter(functional_alltypes.bigint_col > 0)
f = functional_alltypes.filter((functional_alltypes.bigint_col > 0))

result = f.aggregate([f.int_col.nunique().name("nunique")], by=[f.string_col])
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))
difference = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
).difference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


t = ibis.table(name="t", schema={"a": "int64", "b": "string"})
f = t.filter(t.b == "m")
f = t.filter((t.b == "m"))
agg = f.aggregate([f.a.sum().name("sum"), f.a.max()], by=[f.b])
f1 = agg.filter(agg["Max(a)"] == 2)
f1 = agg.filter((agg["Max(a)"] == 2))

result = f1.select(f1.b, f1.sum)
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))
intersection = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
).intersect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
)
param = ibis.param("timestamp")
f = alltypes.filter(alltypes.timestamp_col < param.name("my_param"))
f = alltypes.filter((alltypes.timestamp_col < param.name("my_param")))
agg = f.aggregate([f.float_col.sum().name("foo")], by=[f.string_col])

result = agg.foo.count()
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))

result = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
lit = ibis.timestamp("2018-01-01 00:00:00")
s = ibis.table(name="s", schema={"b": "string"})
t = ibis.table(name="t", schema={"a": "int64", "b": "string", "c": "timestamp"})
f = t.filter(t.c == lit)
f = t.filter((t.c == lit))
dropcolumns = f.select(f.a, f.b, f.c.name("C")).drop("C")
joinchain = (
dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date"))
.inner_join(
s,
dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date")).b == s.b,
(
dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date")).b
== s.b
),
)
.select(dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date")).a)
)

result = joinchain.filter(joinchain.a < 1.0)
result = joinchain.filter((joinchain.a < 1.0))
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))

result = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))

result = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))
union = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
).union(f1.select(f1.string_col.name("key"), f1.double_col.name("value")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

result = (
tpch_region.inner_join(
tpch_nation, tpch_region.r_regionkey == tpch_nation.n_regionkey
tpch_nation, (tpch_region.r_regionkey == tpch_nation.n_regionkey)
)
.select(
tpch_nation.n_nationkey,
Expand Down
Loading

0 comments on commit 0268044

Please sign in to comment.