feat(decompile): make the decompiler run on TPCH query 1 #9779

Merged: 23 commits, Aug 26, 2024

Commits
a925189  feat(parse_sql): add support for sge.Cast and sge.DataType (gforsyth, Jul 3, 2024)
1d5b0de  feat(decompiler): qualify projections, sort keys, and groupbys (gforsyth, Aug 6, 2024)
229e0da  test(decompile): add decompile snapshot test for tpch1 (gforsyth, Aug 6, 2024)
4519386  feat(decompile): add support for sge.Count (gforsyth, Aug 6, 2024)
c96bdd3  fix(decompile): wrap parens around all binary ops (gforsyth, Aug 6, 2024)
275f2a9  chore(decompile): rename snapshot file to avoid ruff (gforsyth, Aug 6, 2024)
d0173ff  feat(decompile): handle converting join filters (gforsyth, Aug 6, 2024)
b79c873  feat(datatypes): add mapping for typecode.DATETIME (gforsyth, Aug 6, 2024)
2d5d5f5  test(decompile): test tpch3 decom recom (gforsyth, Aug 6, 2024)
9253919  chore(snapshots): update snapshots w additional parens in binary ops (gforsyth, Aug 7, 2024)
46a2304  test(tpch): break out tpch results comparison into function (gforsyth, Aug 7, 2024)
83a8956  test(decompile_tpch): unify snapshot name and add ruff ignore (gforsyth, Aug 7, 2024)
242b196  fix(decompile): check that group keys exist, and merge multiple values (gforsyth, Aug 7, 2024)
538fb3f  fix(decompile): don't nuke join predicates if they already exist (gforsyth, Aug 7, 2024)
905401f  chore(decompile): update snapshots (gforsyth, Aug 7, 2024)
db6413d  chore(decompile): update backend sql snapshots (gforsyth, Aug 7, 2024)
89a55fd  fix(limit): include limits from sorts as well as scans (gforsyth, Aug 8, 2024)
741fe47  test(decompile): parametrize test and pull queries from sql files (gforsyth, Aug 8, 2024)
016dc25  test: move tpch decompile tests to separate file (gforsyth, Aug 8, 2024)
878b451  chore: move tests into duckdb backend tests (gforsyth, Aug 8, 2024)
2b0dad8  test(decompile): compare results of decompiled tpch to sql tpch (gforsyth, Aug 8, 2024)
df6ad4a  test(decompile): undo database switch after test (gforsyth, Aug 8, 2024)
2bbbd06  fix(decompile): use `_` for more robust sort key handling (gforsyth, Aug 20, 2024)
Changes from all commits
@@ -0,0 +1,44 @@
import ibis


lineitem = ibis.table(
name="lineitem",
schema={
"l_orderkey": "int32",
"l_partkey": "int32",
"l_suppkey": "int32",
"l_linenumber": "int32",
"l_quantity": "decimal(15, 2)",
"l_extendedprice": "decimal(15, 2)",
"l_discount": "decimal(15, 2)",
"l_tax": "decimal(15, 2)",
"l_returnflag": "string",
"l_linestatus": "string",
"l_shipdate": "date",
"l_commitdate": "date",
"l_receiptdate": "date",
"l_shipinstruct": "string",
"l_shipmode": "string",
"l_comment": "string",
},
)
lit = ibis.literal(1)
f = lineitem.filter((lineitem.l_shipdate <= ibis.literal("1998-09-02").cast("date")))
multiply = f.l_extendedprice * ((lit - f.l_discount))
agg = f.aggregate(
[
f.l_quantity.sum().name("sum_qty"),
f.l_extendedprice.sum().name("sum_base_price"),
multiply.sum().name("sum_disc_price"),
((multiply) * ((lit + f.l_tax))).sum().name("sum_charge"),
f.l_quantity.mean().name("avg_qty"),
f.l_extendedprice.mean().name("avg_price"),
f.l_discount.mean().name("avg_disc"),
f.count().name("count_order"),
],
by=[f.l_returnflag, f.l_linestatus],
)

result = agg.order_by(
agg.l_returnflag.asc(nulls_first=True), agg.l_linestatus.asc(nulls_first=True)
)
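
The script above is the decompiler's snapshot output for TPC-H query 1: the DuckDB SQL is parsed into an ibis expression and then decompiled back to Python source. A minimal sketch of that round trip, using an abbreviated, illustrative catalog and query rather than the full TPC-H schema and query file used by the test below:

import ibis

# Illustrative catalog: table name -> schema mapping, in the same shape as tpch_catalog in the test.
catalog = {"lineitem": {"l_quantity": "decimal(15, 2)", "l_shipdate": "date"}}

sql = """
SELECT sum(l_quantity) AS sum_qty
FROM lineitem
WHERE l_shipdate <= DATE '1998-09-02'
"""

expr = ibis.parse_sql(sql, catalog)       # SQL -> ibis expression
code = ibis.decompile(expr, format=True)  # ibis expression -> formatted Python source
print(code)                               # a small script of the same shape as the snapshot above
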
@@ -0,0 +1,106 @@
import ibis


customer = ibis.table(
name="customer",
schema={
"c_custkey": "int64",
"c_name": "string",
"c_address": "string",
"c_nationkey": "int16",
"c_phone": "string",
"c_acctbal": "decimal",
"c_mktsegment": "string",
"c_comment": "string",
},
)
lit = ibis.literal(True)
orders = ibis.table(
name="orders",
schema={
"o_orderkey": "int64",
"o_custkey": "int64",
"o_orderstatus": "string",
"o_totalprice": "decimal(12, 2)",
"o_orderdate": "date",
"o_orderpriority": "string",
"o_clerk": "string",
"o_shippriority": "int32",
"o_comment": "string",
},
)
lineitem = ibis.table(
name="lineitem",
schema={
"l_orderkey": "int32",
"l_partkey": "int32",
"l_suppkey": "int32",
"l_linenumber": "int32",
"l_quantity": "decimal(15, 2)",
"l_extendedprice": "decimal(15, 2)",
"l_discount": "decimal(15, 2)",
"l_tax": "decimal(15, 2)",
"l_returnflag": "string",
"l_linestatus": "string",
"l_shipdate": "date",
"l_commitdate": "date",
"l_receiptdate": "date",
"l_shipinstruct": "string",
"l_shipmode": "string",
"l_comment": "string",
},
)
cast = ibis.literal("1995-03-15").cast("date")
joinchain = (
customer.inner_join(
orders,
[(customer.c_custkey == orders.o_custkey), lit, (orders.o_orderdate < cast)],
)
.inner_join(
lineitem,
[(orders.o_orderkey == lineitem.l_orderkey), lit, (lineitem.l_shipdate > cast)],
)
.select(
customer.c_custkey,
customer.c_name,
customer.c_address,
customer.c_nationkey,
customer.c_phone,
customer.c_acctbal,
customer.c_mktsegment,
customer.c_comment,
orders.o_orderkey,
orders.o_custkey,
orders.o_orderstatus,
orders.o_totalprice,
orders.o_orderdate,
orders.o_orderpriority,
orders.o_clerk,
orders.o_shippriority,
orders.o_comment,
lineitem.l_orderkey,
lineitem.l_partkey,
lineitem.l_suppkey,
lineitem.l_linenumber,
lineitem.l_quantity,
lineitem.l_extendedprice,
lineitem.l_discount,
lineitem.l_tax,
lineitem.l_returnflag,
lineitem.l_linestatus,
lineitem.l_shipdate,
lineitem.l_commitdate,
lineitem.l_receiptdate,
lineitem.l_shipinstruct,
lineitem.l_shipmode,
lineitem.l_comment,
)
)
f = joinchain.filter((joinchain.c_mktsegment == "BUILDING"))
agg = f.aggregate(
[(f.l_extendedprice * ((1 - f.l_discount))).sum().name("revenue")],
by=[f.l_orderkey, f.o_orderdate, f.o_shippriority],
)
s = agg.order_by(agg.revenue.desc(), agg.o_orderdate.asc(nulls_first=True))

result = s.select(s.l_orderkey, s.revenue, s.o_orderdate, s.o_shippriority).limit(10)
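
Both snapshots bind their final expression to a variable named result, which is what the test below imports dynamically and executes. A sketch of that pattern outside the test, assuming the repository's snapshot modules are importable (the module path mirrors SNAPSHOT_MODULE in the test):

import importlib

import ibis

module = importlib.import_module(
    "ibis.backends.duckdb.tests.snapshots.test_decompile_tpch.test_parse_sql_tpch.tpch03.out_tpch"
)
print(ibis.to_sql(module.result, dialect="duckdb"))  # compile the decompiled expression back to SQL
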
101 changes: 101 additions & 0 deletions ibis/backends/duckdb/tests/test_decompile_tpch.py
@@ -0,0 +1,101 @@
from __future__ import annotations

import importlib
from contextlib import contextmanager
from pathlib import Path

import pytest
from pytest import param

import ibis
from ibis.backends.tests.tpc.conftest import compare_tpc_results
from ibis.formats.pandas import PandasData

tpch_catalog = {

"lineitem": {
"l_orderkey": "int32",
"l_partkey": "int32",
"l_suppkey": "int32",
"l_linenumber": "int32",
"l_quantity": "decimal(15, 2)",
"l_extendedprice": "decimal(15, 2)",
"l_discount": "decimal(15, 2)",
"l_tax": "decimal(15, 2)",
"l_returnflag": "string",
"l_linestatus": "string",
"l_shipdate": "date",
"l_commitdate": "date",
"l_receiptdate": "date",
"l_shipinstruct": "string",
"l_shipmode": "string",
"l_comment": "string",
},
"customer": [
("c_custkey", "int64"),
("c_name", "string"),
("c_address", "string"),
("c_nationkey", "int16"),
("c_phone", "string"),
("c_acctbal", "decimal"),
("c_mktsegment", "string"),
("c_comment", "string"),
],
"orders": [
("o_orderkey", "int64"),
("o_custkey", "int64"),
("o_orderstatus", "string"),
("o_totalprice", "decimal(12,2)"),
("o_orderdate", "date"),
("o_orderpriority", "string"),
("o_clerk", "string"),
("o_shippriority", "int32"),
("o_comment", "string"),
],
}

root = Path(__file__).absolute().parents[3]

SQL_QUERY_PATH = root / "backends" / "tests" / "tpc" / "queries" / "duckdb" / "h"


@contextmanager

def set_database(con, db):
olddb = con.current_database
con.raw_sql(f"USE {db}")
yield
con.raw_sql(f"USE {olddb}")


@pytest.mark.parametrize(

"tpch_query",
[
param(1, id="tpch01"),
param(3, id="tpch03"),
],
)
def test_parse_sql_tpch(tpch_query, snapshot, con, data_dir):
tpch_query_file = SQL_QUERY_PATH / f"{tpch_query:02d}.sql"

with open(tpch_query_file) as f:
sql = f.read()

expr = ibis.parse_sql(sql, tpch_catalog)
code = ibis.decompile(expr, format=True)
snapshot.assert_match(code, "out_tpch.py")

# Import just-created snapshot
SNAPSHOT_MODULE = f"ibis.backends.duckdb.tests.snapshots.test_decompile_tpch.test_parse_sql_tpch.tpch{tpch_query:02d}.out_tpch"

Review comment (Member): yikes 😬

Reply (Member Author): I plan on making this better in a follow up, I promise.

module = importlib.import_module(SNAPSHOT_MODULE)

with set_database(con, "tpch"):
# Get results from executing SQL directly on DuckDB
expected_df = con.con.execute(sql).df()

# Get results from decompiled ibis query
result_df = con.to_pandas(module.result)

# Then set the expected columns so we can coerce the datatypes
# of the pandas dataframe correctly
expected_df.columns = result_df.columns

expected_df = PandasData.convert_table(expected_df, module.result.schema())

compare_tpc_results(result_df, expected_df)
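
As an aside, the set_database helper defined above temporarily switches DuckDB's active database and restores the original one on exit (the cleanup added in commit df6ad4a). A usage sketch, assuming con is the DuckDB backend fixture used by this test and that it already has a "tpch" database attached:

original = con.current_database

with set_database(con, "tpch"):
    # inside the block, raw SQL and ibis expressions run against the "tpch" database
    expected_df = con.con.execute("SELECT 1 AS x").df()

assert con.current_database == original  # the original database is restored on exit
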

1 change: 1 addition & 0 deletions ibis/backends/sql/datatypes.py
@@ -20,6 +20,7 @@
typecode.BOOLEAN: dt.Boolean,
typecode.CHAR: dt.String,
typecode.DATE: dt.Date,
typecode.DATETIME: dt.Timestamp,
typecode.DATE32: dt.Date,
typecode.DOUBLE: dt.Float64,
typecode.ENUM: dt.String,
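
The one-line change above maps sqlglot's DATETIME type code to dt.Timestamp, so a SQL DATETIME encountered while parsing (for example in a CAST) resolves to an ibis timestamp. A rough illustration, with a made-up catalog and query:

import ibis

# Hypothetical single-table catalog; the CAST exercises the DATETIME -> Timestamp mapping.
catalog = {"events": {"raw_ts": "string"}}

expr = ibis.parse_sql("SELECT CAST(raw_ts AS DATETIME) AS ts FROM events", catalog)
print(expr.schema())  # expect "ts" to come back as a timestamp column
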
@@ -3,4 +3,4 @@

int_col_table = ibis.table(name="int_col_table", schema={"int_col": "int32"})

result = (int_col_table.int_col + 4).name("foo")
result = ((int_col_table.int_col + 4)).name("foo")
@@ -19,6 +19,6 @@
"month": "int32",
},
)
f = functional_alltypes.filter(functional_alltypes.bigint_col > 0)
f = functional_alltypes.filter((functional_alltypes.bigint_col > 0))

result = f.aggregate([f.int_col.nunique().name("nunique")], by=[f.string_col])
@@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))
difference = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
).difference(
@@ -2,8 +2,8 @@


t = ibis.table(name="t", schema={"a": "int64", "b": "string"})
f = t.filter(t.b == "m")
f = t.filter((t.b == "m"))
agg = f.aggregate([f.a.sum().name("sum"), f.a.max()], by=[f.b])
f1 = agg.filter(agg["Max(a)"] == 2)
f1 = agg.filter((agg["Max(a)"] == 2))

result = f1.select(f1.b, f1.sum)
@@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))
intersection = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
).intersect(
@@ -11,7 +11,7 @@
},
)
param = ibis.param("timestamp")
f = alltypes.filter(alltypes.timestamp_col < param.name("my_param"))
f = alltypes.filter((alltypes.timestamp_col < param.name("my_param")))
agg = f.aggregate([f.float_col.sum().name("foo")], by=[f.string_col])

result = agg.foo.count()
@@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))

result = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
@@ -4,15 +4,18 @@
lit = ibis.timestamp("2018-01-01 00:00:00")
s = ibis.table(name="s", schema={"b": "string"})
t = ibis.table(name="t", schema={"a": "int64", "b": "string", "c": "timestamp"})
f = t.filter(t.c == lit)
f = t.filter((t.c == lit))
dropcolumns = f.select(f.a, f.b, f.c.name("C")).drop("C")
joinchain = (
dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date"))
.inner_join(
s,
dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date")).b == s.b,
(
dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date")).b
== s.b
),
)
.select(dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date")).a)
)

result = joinchain.filter(joinchain.a < 1.0)
result = joinchain.filter((joinchain.a < 1.0))
@@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))

result = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
@@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))

result = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
@@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))
union = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
).union(f1.select(f1.string_col.name("key"), f1.double_col.name("value")))
@@ -17,7 +17,7 @@

result = (
tpch_region.inner_join(
tpch_nation, tpch_region.r_regionkey == tpch_nation.n_regionkey
tpch_nation, (tpch_region.r_regionkey == tpch_nation.n_regionkey)
)
.select(
tpch_nation.n_nationkey,