diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py index 00bd686c..9e4aad44 100644 --- a/tests/unit/sqlalchemy/test_compiler.py +++ b/tests/unit/sqlalchemy/test_compiler.py @@ -20,11 +20,12 @@ Table, ) from sqlalchemy.schema import CreateTable +from sqlalchemy.sql import column, table from trino.sqlalchemy.dialect import TrinoDialect metadata = MetaData() -table = Table( +table_without_catalog = Table( 'table', metadata, Column('id', Integer), @@ -45,26 +46,26 @@ def dialect(): def test_limit_offset(dialect): - statement = select(table).limit(10).offset(0) + statement = select(table_without_catalog).limit(10).offset(0) query = statement.compile(dialect=dialect) assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1\nLIMIT :param_2' def test_limit(dialect): - statement = select(table).limit(10) + statement = select(table_without_catalog).limit(10) query = statement.compile(dialect=dialect) assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nLIMIT :param_1' def test_offset(dialect): - statement = select(table).offset(0) + statement = select(table_without_catalog).offset(0) query = statement.compile(dialect=dialect) assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1' def test_cte_insert_order(dialect): - cte = select(table).cte('cte') - statement = insert(table).from_select(table.columns, cte) + cte = select(table_without_catalog).cte('cte') + statement = insert(table_without_catalog).from_select(table_without_catalog.columns, cte) query = statement.compile(dialect=dialect) assert str(query) == \ 'INSERT INTO "table" (id, name) WITH cte AS \n'\ @@ -89,3 +90,9 @@ def test_catalogs_create_table(dialect): '\tid INTEGER\n'\ ')\n'\ '\n' + + +def test_table_clause(dialect): + statement = select(table("user", column("id"), column("name"), column("description"))) + query = statement.compile(dialect=dialect) + assert str(query) == 'SELECT user.id, user.name, user.description \nFROM user' diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 99b272fd..0078e689 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -10,17 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from sqlalchemy.sql import compiler -try: - from sqlalchemy.sql.expression import ( - Alias, - CTE, - Subquery, - ) -except ImportError: - # For SQLAlchemy versions < 1.4, the CTE and Subquery classes did not explicitly exist - from sqlalchemy.sql.expression import Alias - CTE = type(None) - Subquery = type(None) +from sqlalchemy.sql.base import DialectKWArgs + # https://trino.io/docs/current/language/reserved.html RESERVED_WORDS = { @@ -122,10 +113,7 @@ def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, @staticmethod def add_catalog(sql, table): - if table is None: - return sql - - if isinstance(table, (Alias, CTE, Subquery)): + if table is None or not isinstance(table, DialectKWArgs): return sql if (