diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index d265a2c621..e9cb6bb86f 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -19,6 +19,7 @@ def __init__(self, tables, env=None):
             env (Optional[dict]): dictionary of functions within the execution context
         """
         self.tables = tables
+        self._table = None  # resolved lazily by the `table` property
         self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
         self.row_readers = {name: table.reader for name, table in tables.items()}
         self.env = {**(env or {}), "scope": self.row_readers}
@@ -29,8 +30,28 @@ def eval(self, code):
     def eval_tuple(self, codes):
         return tuple(self.eval(code) for code in codes)
 
+    @property
+    def table(self):
+        if self._table is None:
+            self._table = list(self.tables.values())[0]
+            for other in self.tables.values():  # every table in the context must have the same shape
+                if self._table.columns != other.columns:
+                    raise Exception("Columns are different.")
+                if len(self._table.rows) != len(other.rows):
+                    raise Exception("Rows are different.")
+        return self._table
+
+    @property
+    def columns(self):
+        return self.table.columns
+
     def __iter__(self):
-        return self.table_iter(list(self.tables)[0])
+        self.env["scope"] = self.row_readers
+        for i in range(len(self.table.rows)):
+            for table in self.tables.values():
+                reader = table[i]  # presumably Table.__getitem__ moves that table's reader to row i
+            yield reader, self  # the reader yielded is the last table's
+
 
     def table_iter(self, table):
         self.env["scope"] = self.row_readers
@@ -38,8 +59,8 @@ def table_iter(self, table):
         for reader in self.tables[table]:
             yield reader, self
 
-    def sort(self, table, key):
-        table = self.tables[table]
+    def sort(self, key):
+        table = self.table
 
         def sort_key(row):
             table.reader.row = row
@@ -47,20 +68,20 @@ def sort_key(row):
 
         table.rows.sort(key=sort_key)
 
-    def set_row(self, table, row):
-        self.row_readers[table].row = row
+    def set_row(self, row):
+        for table in self.tables.values():
+            table.reader.row = row
         self.env["scope"] = self.row_readers
 
-    def set_index(self, table, index):
-        self.row_readers[table].row = self.tables[table].rows[index]
+    def set_index(self, index):
+        for table in self.tables.values():
+            table[index]  # evaluated for its side effect only; presumably positions the table's reader
         self.env["scope"] = self.row_readers
 
-    def set_range(self, table, start, end):
-        self.range_readers[table].range = range(start, end)
+    def set_range(self, start, end):
+        for name in self.tables:
+            self.range_readers[name].range = range(start, end)
         self.env["scope"] = self.range_readers
 
-    def __getitem__(self, table):
-        return self.env["scope"][table]
-
     def __contains__(self, table):
         return table in self.tables
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 1122ef28e1..612a331e2a 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -1,6 +1,7 @@
 import ast
 import collections
 import itertools
+import math
 
 from sqlglot import exp, planner
 from sqlglot.dialects.dialect import Dialect, inline_array_sql
@@ -76,13 +77,10 @@ def table(self, expressions):
         return Table(expression.alias_or_name for expression in expressions)
 
     def scan(self, step, context):
-        if hasattr(step, "source"):
-            source = step.source
+        source = step.source
 
-            if isinstance(source, exp.Expression):
-                source = source.name or source.alias
-        else:
-            source = step.name
+        if isinstance(source, exp.Expression):
+            source = source.name or source.alias
 
         condition = self.generate(step.condition)
         projections = self.generate_tuple(step.projections)
@@ -96,14 +94,12 @@ def scan(self, step, context):
 
         if projections:
             sink = self.table(step.projections)
-        elif source in context:
-            sink = Table(context[source].columns)
         else:
             sink = None
 
         for reader, ctx in table_iter:
             if sink is None:
-                sink = Table(ctx[source].columns)
+                sink = Table(reader.columns)
 
             if condition and not ctx.eval(condition):
                 continue
@@ -135,98 +131,76 @@ def scan_csv(self, step):
                             types.append(type(ast.literal_eval(v)))
                         except (ValueError, SyntaxError):
                             types.append(str)
-                context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
-                yield context[alias], context
+                context.set_row(tuple(t(v) for t, v in zip(types, row)))
+                yield context.table.reader, context
 
     def join(self, step, context):
         source = step.name
 
-        join_context = self.context({source: context.tables[source]})
-
-        def merge_context(ctx, table):
-            # create a new context where all existing tables are mapped to a new one
-            return self.context({name: table for name in ctx.tables})
+        source_table = context.tables[source]
+        source_context = self.context({source: source_table})
+        column_ranges = {source: range(0, len(source_table.columns))}  # table name -> its columns' positions in the merged row
 
         for name, join in step.joins.items():
-            join_context = self.context({**join_context.tables, name: context.tables[name]})
+            table = context.tables[name]
+            start = max(r.stop for r in column_ranges.values())
+            column_ranges[name] = range(start, len(table.columns) + start)
+            join_context = self.context({name: table})
 
             if join.get("source_key"):
-                table = self.hash_join(join, source, name, join_context)
+                table = self.hash_join(join, source_context, join_context)
             else:
-                table = self.nested_loop_join(join, source, name, join_context)
+                table = self.nested_loop_join(join, source_context, join_context)
 
-            join_context = merge_context(join_context, table)
-
-        # apply projections or conditions
-        context = self.scan(step, join_context)
+            source_context = self.context(
+                {name: Table(table.columns, table.rows, column_range) for name, column_range in column_ranges.items()}
+            )
 
-        # use the scan context since it returns a single table
-        # otherwise there are no projections so all other tables are still in scope
-        if step.projections:
-            return context
+        condition = self.generate(step.condition)
+        projections = self.generate_tuple(step.projections)
 
-        return merge_context(join_context, context.tables[source])
+        if not condition and not projections:  # nothing to filter or project: hand back the merged tables
+            return source_context
 
-    def nested_loop_join(self, _join, a, b, context):
-        table = Table(context.tables[a].columns + context.tables[b].columns)
+        sink = self.table(step.projections) if projections else Table(source_context.columns)  # columns are plain names here
 
-        for reader_a, _ in context.table_iter(a):
-            for reader_b, _ in context.table_iter(b):
-                table.append(reader_a.row + reader_b.row)
+        for reader, ctx in source_context:  # iterate the merged rows, not only the last joined table
+            if condition and not ctx.eval(condition):
+                continue
 
-        return table
+            if projections:
+                sink.append(ctx.eval_tuple(projections))
+            else:
+                sink.append(reader.row)
 
-    def hash_join(self, join, a, b, context):
-        a_key = self.generate_tuple(join["source_key"])
-        b_key = self.generate_tuple(join["join_key"])
+            if len(sink) >= step.limit:
+                break
 
-        results = collections.defaultdict(lambda: ([], []))
+        return self.context({step.name: sink})
 
-        for reader, ctx in context.table_iter(a):
-            results[ctx.eval_tuple(a_key)][0].append(reader.row)
-        for reader, ctx in context.table_iter(b):
-            results[ctx.eval_tuple(b_key)][1].append(reader.row)
+    def nested_loop_join(self, _join, source_context, join_context):
+        table = Table(source_context.columns + join_context.columns)
 
-        table = Table(context.tables[a].columns + context.tables[b].columns)
-        for a_group, b_group in results.values():
-            for a_row, b_row in itertools.product(a_group, b_group):
-                table.append(a_row + b_row)
+        for reader_a, _ in source_context:
+            for reader_b, _ in join_context:
+                table.append(reader_a.row + reader_b.row)
 
         return table
 
-    def sort_merge_join(self, join, a, b, context):
-        a_key = self.generate_tuple(join["source_key"])
-        b_key = self.generate_tuple(join["join_key"])
-
-        context.sort(a, a_key)
-        context.sort(b, b_key)
-
-        a_i = 0
-        b_i = 0
-        a_n = len(context.tables[a])
-        b_n = len(context.tables[b])
-
-        table = Table(context.tables[a].columns + context.tables[b].columns)
+    def hash_join(self, join, source_context, join_context):
+        source_key = self.generate_tuple(join["source_key"])
+        join_key = self.generate_tuple(join["join_key"])
 
-        def get_key(source, key, i):
-            context.set_index(source, i)
-            return context.eval_tuple(key)
-
-        while a_i < a_n and b_i < b_n:
-            key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))
-
-            a_group = []
-
-            while a_i < a_n and key == get_key(a, a_key, a_i):
-                a_group.append(context[a].row)
-                a_i += 1
+        results = collections.defaultdict(lambda: ([], []))  # key -> (source rows, join rows)
 
-            b_group = []
+        for reader, ctx in source_context:
+            results[ctx.eval_tuple(source_key)][0].append(reader.row)
+        for reader, ctx in join_context:
+            results[ctx.eval_tuple(join_key)][1].append(reader.row)
 
-            while b_i < b_n and key == get_key(b, b_key, b_i):
-                b_group.append(context[b].row)
-                b_i += 1
+        table = Table(source_context.columns + join_context.columns)
 
+        for a_group, b_group in results.values():
             for a_row, b_row in itertools.product(a_group, b_group):
                 table.append(a_row + b_row)
 
@@ -238,16 +212,16 @@ def aggregate(self, step, context):
         aggregations = self.generate_tuple(step.aggregations)
         operands = self.generate_tuple(step.operands)
 
-        context.sort(source, group_by)
-
-        if step.operands:
+        if operands:
             source_table = context.tables[source]
             operand_table = Table(source_table.columns + self.table(step.operands).columns)
 
             for reader, ctx in context:
                 operand_table.append(reader.row + ctx.eval_tuple(operands))
 
-            context = self.context({source: operand_table})
+            context = self.context({None: operand_table, **{table: operand_table for table in context.tables}})
+
+        context.sort(group_by)
 
         group = None
         start = 0
@@ -256,15 +230,15 @@ def aggregate(self, step, context):
         table = self.table(step.group + step.aggregations)
 
         for i in range(length):
-            context.set_index(source, i)
+            context.set_index(i)
             key = context.eval_tuple(group_by)
             group = key if group is None else group
             end += 1
 
             if i == length - 1:
-                context.set_range(source, start, end - 1)
+                context.set_range(start, end - 1)
             elif key != group:
-                context.set_range(source, start, end - 2)
+                context.set_range(start, end - 2)
             else:
                 continue
 
@@ -272,13 +246,32 @@ def aggregate(self, step, context):
             group = key
             start = end - 2
 
-        return self.scan(step, self.context({source: table}))
+        context = self.context({step.name: table, **{name: table for name in context.tables}})
+
+        if step.projections:
+            return self.scan(step, context)
+        return context
 
     def sort(self, step, context):
-        table = list(context.tables)[0]
-        key = self.generate_tuple(step.key)
-        context.sort(table, key)
-        return self.scan(step, context)
+        projections = self.generate_tuple(step.projections)
+
+        sink = self.table(step.projections)
+
+        for reader, ctx in context:
+            sink.append(ctx.eval_tuple(projections))
+
+        context = self.context(
+            {
+                None: sink,
+                **{table: sink for table in context.tables},
+            }
+        )
+        context.sort(self.generate_tuple(step.key))
+
+        if not math.isinf(step.limit):  # an infinite limit cannot be used as a slice bound
+            context.table.rows = context.table.rows[0 : step.limit]
+
+        return self.context({step.name: context.table})
 
 
 def _cast_py(self, expression):
@@ -293,7 +286,7 @@ def _cast_py(self, expression):
 
 
 def _column_py(self, expression):
-    table = self.sql(expression, "table")
+    table = self.sql(expression, "table") or None  # unqualified columns are looked up under the None scope key
     this = self.sql(expression, "this")
     return f"scope[{table}][{this}]"
 
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 80674cb5e7..ca2760aaff 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -1,10 +1,12 @@
 class Table:
-    def __init__(self, *columns, rows=None):
-        self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
+    def __init__(self, columns, rows=None, column_range=None):
+        self.columns = tuple(columns)
+        self.column_range = column_range  # optional range of column positions this table exposes
+        self.reader = RowReader(self.columns, self.column_range)
+
         self.rows = rows or []
         if rows:
             assert len(rows[0]) == len(self.columns)
-        self.reader = RowReader(self.columns)
         self.range_reader = RangeReader(self)
 
     def append(self, row):
@@ -29,16 +31,17 @@ def __getitem__(self, index):
         return self.reader
 
     def __repr__(self):
-        widths = {column: len(column) for column in self.columns}
-        lines = [" ".join(column for column in self.columns)]
+        columns = tuple(
+            column for i, column in enumerate(self.columns) if not self.column_range or i in self.column_range
+        )
+        widths = {column: len(column) for column in columns}
+        lines = [" ".join(column for column in columns)]
 
         for i, row in enumerate(self):
             if i > 10:
                 break
 
-            lines.append(
-                " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
-            )
+            lines.append(" ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns))
         return "\n".join(lines)
 
 
@@ -70,8 +73,8 @@ def __getitem__(self, column):
 
 
 class RowReader:
-    def __init__(self, columns):
-        self.columns = {column: i for i, column in enumerate(columns)}
+    def __init__(self, columns, column_range=None):
+        self.columns = {column: i for i, column in enumerate(columns) if not column_range or i in column_range}
         self.row = None
 
     def __getitem__(self, column):
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index ea995d8fd6..c9bd72535a 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -102,7 +102,7 @@ def from_expression(cls, expression, ctes=None):
                         continue
                     if operand not in operands:
                         operands[operand] = f"_a_{next(sequence)}"
-                    operand.replace(exp.column(operands[operand], step.name, quoted=True))
+                    operand.replace(exp.column(operands[operand], quoted=True))
             else:
                 projections.append(e)
 
@@ -119,7 +119,7 @@ def from_expression(cls, expression, ctes=None):
             aggregate.name = step.name
             aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
             aggregate.aggregations = aggregations
-            aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
+            aggregate.group = group.expressions
             aggregate.add_dependency(step)
             step = aggregate
 
@@ -136,9 +136,6 @@ def from_expression(cls, expression, ctes=None):
             sort.key = order.expressions
             sort.add_dependency(step)
             step = sort
-            for k in sort.key + projections:
-                for column in k.find_all(exp.Column):
-                    column.set("table", exp.to_identifier(step.name, quoted=True))
 
         step.projections = projections
 