From 62d3496e761a4f38dfa61af793062690923dce74 Mon Sep 17 00:00:00 2001 From: tobymao Date: Thu, 10 Nov 2022 20:51:48 -0800 Subject: [PATCH] allow direct table execution --- sqlglot/executor/__init__.py | 10 ++++++++-- tests/test_executor.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index 43116730ba..35cd86f67d 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -13,7 +13,7 @@ logger = logging.getLogger("sqlglot") -def execute(sql, schema, read=None, tables=None): +def execute(sql, schema=None, read=None, tables=None): """ Run a sql query against data. @@ -31,8 +31,14 @@ def execute(sql, schema, read=None, tables=None): Returns: sqlglot.executor.Table: Simple columnar data structure. """ - schema = ensure_schema(schema) tables = ensure_tables(tables) + if not schema: + # do a real type mapping one day + schema = { + name: {column: type(table[0][column]).__name__ for column in table.columns} + for name, table in tables.mapping.items() + } + schema = ensure_schema(schema) if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args: raise ExecuteError("Tables must support the same table args as schema") expression = parse_one(sql, read=read) diff --git a/tests/test_executor.py b/tests/test_executor.py index 049f479f09..f8daf84240 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -157,6 +157,41 @@ def test_execute_catalog_db_table(self): assert result1.columns == result2.columns assert result1.rows == result2.rows + def test_execute_tables(self): + tables = { + "sushi": [ + {"id": 1, "price": 1.0}, + {"id": 2, "price": 2.0}, + {"id": 3, "price": 3.0}, + ], + "order_items": [ + {"sushi_id": 1, "order_id": 1}, + {"sushi_id": 1, "order_id": 1}, + {"sushi_id": 2, "order_id": 1}, + ], + "orders": [ + {"id": 1, "user_id": 1}, + ], + } + + self.assertEqual( + execute( + """ + SELECT + o.user_id, + SUM(s.price) AS price + FROM orders o + JOIN order_items i + ON o.id = i.order_id + JOIN sushi s + ON i.sushi_id = s.id + GROUP BY o.user_id + """, + tables=tables, + ).rows, + [(1, 4.0)], + ) + def test_table_depth_mismatch(self): tables = {"table": []} schema = {"db": {"table": {"col": "VARCHAR"}}}