Skip to content

Commit

Permalink
allow direct table execution
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Nov 11, 2022
1 parent 1d7c5e7 commit 62d3496
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
10 changes: 8 additions & 2 deletions sqlglot/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
logger = logging.getLogger("sqlglot")


def execute(sql, schema, read=None, tables=None):
def execute(sql, schema=None, read=None, tables=None):
"""
Run a sql query against data.
Expand All @@ -31,8 +31,14 @@ def execute(sql, schema, read=None, tables=None):
Returns:
sqlglot.executor.Table: Simple columnar data structure.
"""
schema = ensure_schema(schema)
tables = ensure_tables(tables)
if not schema:
# do a real type mapping one day
schema = {
name: {column: type(table[0][column]).__name__ for column in table.columns}
for name, table in tables.mapping.items()
}
schema = ensure_schema(schema)
if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
raise ExecuteError("Tables must support the same table args as schema")
expression = parse_one(sql, read=read)
Expand Down
35 changes: 35 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,41 @@ def test_execute_catalog_db_table(self):
assert result1.columns == result2.columns
assert result1.rows == result2.rows

def test_execute_tables(self):
tables = {
"sushi": [
{"id": 1, "price": 1.0},
{"id": 2, "price": 2.0},
{"id": 3, "price": 3.0},
],
"order_items": [
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 1, "order_id": 1},
{"sushi_id": 2, "order_id": 1},
],
"orders": [
{"id": 1, "user_id": 1},
],
}

self.assertEqual(
execute(
"""
SELECT
o.user_id,
SUM(s.price) AS price
FROM orders o
JOIN order_items i
ON o.id = i.order_id
JOIN sushi s
ON i.sushi_id = s.id
GROUP BY o.user_id
""",
tables=tables,
).rows,
[(1, 4.0)],
)

def test_table_depth_mismatch(self):
tables = {"table": []}
schema = {"db": {"table": {"col": "VARCHAR"}}}
Expand Down

0 comments on commit 62d3496

Please sign in to comment.