Skip to content

Commit

Permalink
allow joining with same names
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Oct 31, 2022
1 parent c3ec257 commit 248a2ab
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 115 deletions.
45 changes: 33 additions & 12 deletions sqlglot/executor/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, tables, env=None):
env (Optional[dict]): dictionary of functions within the execution context
"""
self.tables = tables
self._table = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
Expand All @@ -29,38 +30,58 @@ def eval(self, code):
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)

@property
def table(self):
if self._table is None:
self._table = list(self.tables.values())[0]
for other in self.tables.values():
if self._table.columns != other.columns:
raise Exception(f"Columns are different.")
if len(self._table.rows) != len(other.rows):
raise Exception(f"Rows are different.")
return self._table

@property
def columns(self):
return self.table.columns

def __iter__(self):
return self.table_iter(list(self.tables)[0])
self.env["scope"] = self.row_readers
for i in range(len(self.table.rows)):
for table in self.tables.values():
reader = table[i]
yield reader, self


def table_iter(self, table):
self.env["scope"] = self.row_readers

for reader in self.tables[table]:
yield reader, self

def sort(self, table, key):
table = self.tables[table]
def sort(self, key):
table = self.table

def sort_key(row):
table.reader.row = row
return self.eval_tuple(key)

table.rows.sort(key=sort_key)

def set_row(self, table, row):
self.row_readers[table].row = row
def set_row(self, row):
for table in self.tables.values():
table.reader.row = row
self.env["scope"] = self.row_readers

def set_index(self, table, index):
self.row_readers[table].row = self.tables[table].rows[index]
def set_index(self, index):
for table in self.tables.values():
table[index]
self.env["scope"] = self.row_readers

def set_range(self, table, start, end):
self.range_readers[table].range = range(start, end)
def set_range(self, start, end):
for name in self.tables:
self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers

def __getitem__(self, table):
return self.env["scope"][table]

def __contains__(self, table):
return table in self.tables
169 changes: 81 additions & 88 deletions sqlglot/executor/python.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
import collections
import itertools
import math

from sqlglot import exp, planner
from sqlglot.dialects.dialect import Dialect, inline_array_sql
Expand Down Expand Up @@ -76,13 +77,10 @@ def table(self, expressions):
return Table(expression.alias_or_name for expression in expressions)

def scan(self, step, context):
if hasattr(step, "source"):
source = step.source
source = step.source

if isinstance(source, exp.Expression):
source = source.name or source.alias
else:
source = step.name
if isinstance(source, exp.Expression):
source = source.name or source.alias

condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
Expand All @@ -96,14 +94,12 @@ def scan(self, step, context):

if projections:
sink = self.table(step.projections)
elif source in context:
sink = Table(context[source].columns)
else:
sink = None

for reader, ctx in table_iter:
if sink is None:
sink = Table(ctx[source].columns)
sink = Table(reader.columns)

if condition and not ctx.eval(condition):
continue
Expand Down Expand Up @@ -135,98 +131,76 @@ def scan_csv(self, step):
types.append(type(ast.literal_eval(v)))
except (ValueError, SyntaxError):
types.append(str)
context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
yield context[alias], context
context.set_row(tuple(t(v) for t, v in zip(types, row)))
yield context.table.reader, context

def join(self, step, context):
source = step.name

join_context = self.context({source: context.tables[source]})

def merge_context(ctx, table):
# create a new context where all existing tables are mapped to a new one
return self.context({name: table for name in ctx.tables})
source_table = context.tables[source]
source_context = self.context({source: source_table})
column_ranges = {source: range(0, len(source_table.columns))}

for name, join in step.joins.items():
join_context = self.context({**join_context.tables, name: context.tables[name]})
table = context.tables[name]
start = max(r.stop for r in column_ranges.values())
column_ranges[name] = range(start, len(table.columns) + start)
join_context = self.context({name: table})

if join.get("source_key"):
table = self.hash_join(join, source, name, join_context)
table = self.hash_join(join, source_context, join_context)
else:
table = self.nested_loop_join(join, source, name, join_context)
table = self.nested_loop_join(join, source_context, join_context)

join_context = merge_context(join_context, table)

# apply projections or conditions
context = self.scan(step, join_context)
source_context = self.context(
{name: Table(table.columns, table.rows, column_range) for name, column_range in column_ranges.items()}
)

# use the scan context since it returns a single table
# otherwise there are no projections so all other tables are still in scope
if step.projections:
return context
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)

return merge_context(join_context, context.tables[source])
if not condition or not projections:
return source_context

def nested_loop_join(self, _join, a, b, context):
table = Table(context.tables[a].columns + context.tables[b].columns)
sink = self.table(step.projections if projections else source_context.columns)

for reader_a, _ in context.table_iter(a):
for reader_b, _ in context.table_iter(b):
table.append(reader_a.row + reader_b.row)
for reader, ctx in join_context:
if condition and not ctx.eval(condition):
continue

return table
if projections:
sink.append(ctx.eval_tuple(projections))
else:
sink.append(reader.row)

def hash_join(self, join, a, b, context):
a_key = self.generate_tuple(join["source_key"])
b_key = self.generate_tuple(join["join_key"])
if len(sink) >= step.limit:
break

results = collections.defaultdict(lambda: ([], []))
return self.context({step.name: sink})

for reader, ctx in context.table_iter(a):
results[ctx.eval_tuple(a_key)][0].append(reader.row)
for reader, ctx in context.table_iter(b):
results[ctx.eval_tuple(b_key)][1].append(reader.row)
def nested_loop_join(self, _join, source_context, join_context):
table = Table(source_context.columns + join_context.columns)

table = Table(context.tables[a].columns + context.tables[b].columns)
for a_group, b_group in results.values():
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
for reader_a, _ in source_context:
for reader_b, _ in join_context:
table.append(reader_a.row + reader_b.row)

return table

def sort_merge_join(self, join, a, b, context):
a_key = self.generate_tuple(join["source_key"])
b_key = self.generate_tuple(join["join_key"])

context.sort(a, a_key)
context.sort(b, b_key)

a_i = 0
b_i = 0
a_n = len(context.tables[a])
b_n = len(context.tables[b])

table = Table(context.tables[a].columns + context.tables[b].columns)
def hash_join(self, join, source_context, join_context):
source_key = self.generate_tuple(join["source_key"])
join_key = self.generate_tuple(join["join_key"])

def get_key(source, key, i):
context.set_index(source, i)
return context.eval_tuple(key)

while a_i < a_n and b_i < b_n:
key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))

a_group = []

while a_i < a_n and key == get_key(a, a_key, a_i):
a_group.append(context[a].row)
a_i += 1
results = collections.defaultdict(lambda: ([], []))

b_group = []
for reader, ctx in source_context:
results[ctx.eval_tuple(source_key)][0].append(reader.row)
for reader, ctx in join_context:
results[ctx.eval_tuple(join_key)][1].append(reader.row)

while b_i < b_n and key == get_key(b, b_key, b_i):
b_group.append(context[b].row)
b_i += 1
table = Table(source_context.columns + join_context.columns)

for a_group, b_group in results.values():
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)

Expand All @@ -238,16 +212,16 @@ def aggregate(self, step, context):
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)

context.sort(source, group_by)

if step.operands:
if operands:
source_table = context.tables[source]
operand_table = Table(source_table.columns + self.table(step.operands).columns)

for reader, ctx in context:
operand_table.append(reader.row + ctx.eval_tuple(operands))

context = self.context({source: operand_table})
context = self.context({None: operand_table, **{table: operand_table for table in context.tables}})

context.sort(group_by)

group = None
start = 0
Expand All @@ -256,29 +230,48 @@ def aggregate(self, step, context):
table = self.table(step.group + step.aggregations)

for i in range(length):
context.set_index(source, i)
context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1

if i == length - 1:
context.set_range(source, start, end - 1)
context.set_range(start, end - 1)
elif key != group:
context.set_range(source, start, end - 2)
context.set_range(start, end - 2)
else:
continue

table.append(group + context.eval_tuple(aggregations))
group = key
start = end - 2

return self.scan(step, self.context({source: table}))
context = self.context({step.name: table, **{name: table for name in context.tables}})

if step.projections:
return self.scan(step, context)
return context

def sort(self, step, context):
table = list(context.tables)[0]
key = self.generate_tuple(step.key)
context.sort(table, key)
return self.scan(step, context)
projections = self.generate_tuple(step.projections)

sink = self.table(step.projections)

for reader, ctx in context:
sink.append(ctx.eval_tuple(projections))

context = self.context(
{
None: sink,
**{table: sink for table in context.tables},
}
)
context.sort(self.generate_tuple(step.key))

if not math.isinf(step.limit):
context.table.rows = context.table.rows[0 : step.limit]

return self.context({step.name: context.table})


def _cast_py(self, expression):
Expand All @@ -293,7 +286,7 @@ def _cast_py(self, expression):


def _column_py(self, expression):
table = self.sql(expression, "table")
table = self.sql(expression, "table") or None
this = self.sql(expression, "this")
return f"scope[{table}][{this}]"

Expand Down
23 changes: 13 additions & 10 deletions sqlglot/executor/table.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
class Table:
def __init__(self, *columns, rows=None):
self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
def __init__(self, columns, rows=None, column_range=None):
self.columns = tuple(columns)
self.column_range = column_range
self.reader = RowReader(self.columns, self.column_range)

self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
self.reader = RowReader(self.columns)
self.range_reader = RangeReader(self)

def append(self, row):
Expand All @@ -29,16 +31,17 @@ def __getitem__(self, index):
return self.reader

def __repr__(self):
widths = {column: len(column) for column in self.columns}
lines = [" ".join(column for column in self.columns)]
columns = tuple(
column for i, column in enumerate(self.columns) if not self.column_range or i in self.column_range
)
widths = {column: len(column) for column in columns}
lines = [" ".join(column for column in columns)]

for i, row in enumerate(self):
if i > 10:
break

lines.append(
" ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
)
lines.append(" ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns))
return "\n".join(lines)


Expand Down Expand Up @@ -70,8 +73,8 @@ def __getitem__(self, column):


class RowReader:
def __init__(self, columns):
self.columns = {column: i for i, column in enumerate(columns)}
def __init__(self, columns, column_range=None):
self.columns = {column: i for i, column in enumerate(columns) if not column_range or i in column_range}
self.row = None

def __getitem__(self, column):
Expand Down
Loading

0 comments on commit 248a2ab

Please sign in to comment.