Skip to content

Commit

Permalink
Toby/subqueries (#762)
Browse files Browse the repository at this point in the history
* unnest scalar subqueries as cross joins

fixes #748

* pass up to tpc-h 17

* more tests
  • Loading branch information
tobymao authored Nov 25, 2022
1 parent 4373ad8 commit a943117
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
1 change: 0 additions & 1 deletion sqlglot/executor/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def interval(this, unit):


ENV = {
"__builtins__": {},
"exp": exp,
# aggs
"SUM": filter_nulls(sum),
Expand Down
1 change: 1 addition & 0 deletions sqlglot/executor/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ class Generator(generator.Generator):
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
exp.Is: lambda self, e: self.binary(e, "is"),
Expand Down
7 changes: 1 addition & 6 deletions sqlglot/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,7 @@ def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
table = expression
alias_ = expression.alias

if not alias_:
raise UnsupportedError(
"Tables/Subqueries must be aliased. Run it through the optimizer"
)
alias_ = expression.alias_or_name

if isinstance(expression, exp.Subquery):
table = expression.this
Expand Down
10 changes: 8 additions & 2 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def test_optimized_tpch(self):

def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table):
if isinstance(expression, exp.Table) and expression.name not in ("revenue"):
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression

for i, (sql, _) in enumerate(self.sqls[0:14]):
for i, (sql, _) in enumerate(self.sqls[0:16]):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
Expand Down Expand Up @@ -484,6 +484,12 @@ def test_scalar_functions(self):
("IF(false, 1, 0)", 0),
("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
("1 IN (1, 2, 3)", True),
("1 IN (2, 3)", False),
("NULL IS NULL", True),
("NULL IS NOT NULL", False),
("NULL = NULL", None),
("NULL <> NULL", None),
]:
with self.subTest(sql):
result = execute(f"SELECT {sql}")
Expand Down

0 comments on commit a943117

Please sign in to comment.