Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding compatibility with SQLAlchemy 2.0 #457

Merged
merged 1 commit into from
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
# Presto
engine = create_engine('presto://localhost:8080/hive/default')
# Trino
engine = create_engine('trino://localhost:8080/hive/default')
engine = create_engine('trino+pyhive://localhost:8080/hive/default')
# Hive
engine = create_engine('hive://localhost:10000/default')

# SQLAlchemy < 2.0
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
print select([func.count('*')], from_obj=logs).scalar()

Expand All @@ -82,6 +84,20 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
print select([func.count('*')], from_obj=logs).scalar()

# SQLAlchemy >= 2.0
metadata_obj = MetaData()
books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String))
metadata_obj.create_all(engine)
inspector = inspect(engine)
inspector.get_columns('books')

with engine.connect() as con:
data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" },
{ "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }]
con.execute(books.insert(), data[0])
result = con.execute(text("select * from books"))
print(result.fetchall())

Note: query generation functionality is not exhaustive or fully tested, but there should be no
problem with raw SQL.

Expand Down
30 changes: 23 additions & 7 deletions pyhive/sqlalchemy_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@

import re
from sqlalchemy import exc
from sqlalchemy import processors
from sqlalchemy.sql import text
try:
from sqlalchemy import processors
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.engine import processors
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
Expand Down Expand Up @@ -121,7 +132,7 @@ def __init__(self, dialect):

_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'int': types.Integer,
'bigint': types.BigInteger,
Expand Down Expand Up @@ -247,10 +258,15 @@ class HiveDialect(default.DefaultDialect):
supports_multivalues_insert = True
type_compiler = HiveTypeCompiler
supports_sane_rowcount = False
supports_statement_cache = False

@classmethod
def dbapi(cls):
return hive

@classmethod
def import_dbapi(cls):
mdeshmu marked this conversation as resolved.
Show resolved Hide resolved
return hive

def create_connect_args(self, url):
kwargs = {
Expand All @@ -265,7 +281,7 @@ def create_connect_args(self, url):

def get_schema_names(self, connection, **kw):
# Equivalent to SHOW DATABASES
return [row[0] for row in connection.execute('SHOW SCHEMAS')]
return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))]
mdeshmu marked this conversation as resolved.
Show resolved Hide resolved

def get_view_names(self, connection, schema=None, **kw):
# Hive does not provide functionality to query tableType
Expand All @@ -280,7 +296,7 @@ def _get_table_columns(self, connection, table_name, schema):
# Using DESCRIBE works but is uglier.
try:
# This needs the table name to be unescaped (no backticks).
rows = connection.execute('DESCRIBE {}'.format(full_table)).fetchall()
rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall()
except exc.OperationalError as e:
# Does the table exist?
regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}'
Expand All @@ -296,7 +312,7 @@ def _get_table_columns(self, connection, table_name, schema):
raise exc.NoSuchTableError(full_table)
return rows

def has_table(self, connection, table_name, schema=None):
def has_table(self, connection, table_name, schema=None, **kw):
mdeshmu marked this conversation as resolved.
Show resolved Hide resolved
try:
self._get_table_columns(connection, table_name, schema)
return True
Expand Down Expand Up @@ -361,7 +377,7 @@ def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' IN ' + self.identifier_preparer.quote_identifier(schema)
return [row[0] for row in connection.execute(query)]
return [row[0] for row in connection.execute(text(query))]

def do_rollback(self, dbapi_connection):
# No transactions for Hive
Expand Down
28 changes: 22 additions & 6 deletions pyhive/sqlalchemy_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,27 @@
from __future__ import unicode_literals

import re
import sqlalchemy
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
from sqlalchemy.sql import text
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler

from pyhive import presto
from pyhive.common import UniversalSet

# Installed SQLAlchemy version reduced to "major.minor" and stored as a float
# (e.g. "1.4.46" -> 1.4). get_indexes compares it against 1.4 to decide
# whether a result row must be accessed through `row._mapping`.
# NOTE(review): the pattern requires a third dot-separated component; for a
# version string without one, re.search returns None and `.group(1)` raises
# AttributeError at import time. Float comparison would also misorder a
# two-digit minor ("1.10" -> 1.1 < 1.4) — confirm neither case can occur.
sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))

class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
# Just quote everything to make things simpler / easier to upgrade
Expand All @@ -29,7 +38,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer):

_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'integer': types.Integer,
'bigint': types.BigInteger,
Expand Down Expand Up @@ -80,6 +89,7 @@ class PrestoDialect(default.DefaultDialect):
supports_multivalues_insert = True
supports_unicode_statements = True
supports_unicode_binds = True
supports_statement_cache = False
mdeshmu marked this conversation as resolved.
Show resolved Hide resolved
returns_unicode_strings = True
description_encoding = None
supports_native_boolean = True
Expand All @@ -88,6 +98,10 @@ class PrestoDialect(default.DefaultDialect):
@classmethod
def dbapi(cls):
return presto

@classmethod
def import_dbapi(cls):
return presto

def create_connect_args(self, url):
db_parts = (url.database or 'hive').split('/')
Expand All @@ -108,14 +122,14 @@ def create_connect_args(self, url):
return [], kwargs

def get_schema_names(self, connection, **kw):
return [row.Schema for row in connection.execute('SHOW SCHEMAS')]
return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))]

def _get_table_columns(self, connection, table_name, schema):
full_table = self.identifier_preparer.quote_identifier(table_name)
if schema:
full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
try:
return connection.execute('SHOW COLUMNS FROM {}'.format(full_table))
return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table)))
except (presto.DatabaseError, exc.DatabaseError) as e:
# Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
# it successfully does in the Hive version. The difference with Presto is that this
Expand All @@ -134,7 +148,7 @@ def _get_table_columns(self, connection, table_name, schema):
else:
raise

def has_table(self, connection, table_name, schema=None):
def has_table(self, connection, table_name, schema=None, **kw):
try:
self._get_table_columns(connection, table_name, schema)
return True
Expand Down Expand Up @@ -176,6 +190,8 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
# - a boolean column named "Partition Key"
# - a string in the "Comment" column
# - a string in the "Extra" column
if sqlalchemy_version >= 1.4:
row = row._mapping
mdeshmu marked this conversation as resolved.
Show resolved Hide resolved
is_partition_key = (
(part_key in row and row[part_key])
or row['Comment'].startswith(part_key)
Expand All @@ -192,7 +208,7 @@ def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
return [row.Table for row in connection.execute(query)]
return [row.Table for row in connection.execute(text(query))]

def do_rollback(self, dbapi_connection):
# No transactions for Presto
Expand Down
15 changes: 13 additions & 2 deletions pyhive/sqlalchemy_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
Expand All @@ -28,7 +34,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer):

_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'integer': types.Integer,
'bigint': types.BigInteger,
Expand Down Expand Up @@ -67,7 +73,12 @@ def visit_TEXT(self, type_, **kw):

# Trino dialect: subclasses PrestoDialect and, in the lines visible here,
# overrides only the dialect name, the statement-cache flag and the two
# DBAPI-module hooks.
class TrinoDialect(PrestoDialect):
name = 'trino'
# Per SQLAlchemy's dialect API, False opts this dialect out of compiled-SQL caching.
supports_statement_cache = False

# DBAPI hook under its older name; returns the module-level `trino` object.
@classmethod
def dbapi(cls):
return trino

# Same hook under the name SQLAlchemy 2.0 looks up (Dialect.import_dbapi);
# it returns the same `trino` object as dbapi() above.
@classmethod
def import_dbapi(cls):
return trino
Loading