Skip to content

Commit

Permalink
minimal changes for sqlalchemy 2.0 support
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeshmu committed Jun 28, 2023
1 parent 4367cc5 commit 838727e
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 94 deletions.
18 changes: 17 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
# Presto
engine = create_engine('presto://localhost:8080/hive/default')
# Trino
engine = create_engine('trino://localhost:8080/hive/default')
engine = create_engine('trino+pyhive://localhost:8080/hive/default')
# Hive
engine = create_engine('hive://localhost:10000/default')
# SQLAlchemy < 2.0
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
print select([func.count('*')], from_obj=logs).scalar()
Expand All @@ -82,6 +84,20 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
print select([func.count('*')], from_obj=logs).scalar()
# SQLAlchemy >= 2.0
metadata_obj = MetaData()
books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String))
metadata_obj.create_all(engine)
inspector = inspect(engine)
inspector.get_columns('books')
with engine.connect() as con:
data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" },
{ "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }]
con.execute(books.insert(), data[0])
result = con.execute(text("select * from books"))
print(result.fetchall())
Note: query generation functionality is not exhaustive or fully tested, but there should be no
problem with raw SQL.

Expand Down
30 changes: 23 additions & 7 deletions pyhive/sqlalchemy_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@

import re
from sqlalchemy import exc
from sqlalchemy import processors
from sqlalchemy.sql import text
try:
from sqlalchemy import processors
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.engine import processors
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
Expand Down Expand Up @@ -121,7 +132,7 @@ def __init__(self, dialect):

_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'int': types.Integer,
'bigint': types.BigInteger,
Expand Down Expand Up @@ -247,10 +258,15 @@ class HiveDialect(default.DefaultDialect):
supports_multivalues_insert = True
type_compiler = HiveTypeCompiler
supports_sane_rowcount = False
supports_statement_cache = False

@classmethod
def dbapi(cls):
return hive

@classmethod
def import_dbapi(cls):
return hive

def create_connect_args(self, url):
kwargs = {
Expand All @@ -265,7 +281,7 @@ def create_connect_args(self, url):

def get_schema_names(self, connection, **kw):
# Equivalent to SHOW DATABASES
return [row[0] for row in connection.execute('SHOW SCHEMAS')]
return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))]

def get_view_names(self, connection, schema=None, **kw):
# Hive does not provide functionality to query tableType
Expand All @@ -280,7 +296,7 @@ def _get_table_columns(self, connection, table_name, schema):
# Using DESCRIBE works but is uglier.
try:
# This needs the table name to be unescaped (no backticks).
rows = connection.execute('DESCRIBE {}'.format(full_table)).fetchall()
rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall()
except exc.OperationalError as e:
# Does the table exist?
regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}'
Expand All @@ -296,7 +312,7 @@ def _get_table_columns(self, connection, table_name, schema):
raise exc.NoSuchTableError(full_table)
return rows

def has_table(self, connection, table_name, schema=None):
def has_table(self, connection, table_name, schema=None, **kw):
try:
self._get_table_columns(connection, table_name, schema)
return True
Expand Down Expand Up @@ -361,7 +377,7 @@ def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' IN ' + self.identifier_preparer.quote_identifier(schema)
return [row[0] for row in connection.execute(query)]
return [row[0] for row in connection.execute(text(query))]

def do_rollback(self, dbapi_connection):
# No transactions for Hive
Expand Down
28 changes: 22 additions & 6 deletions pyhive/sqlalchemy_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,27 @@
from __future__ import unicode_literals

import re
import sqlalchemy
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
from sqlalchemy.sql import text
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler

from pyhive import presto
from pyhive.common import UniversalSet

sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))

class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
# Just quote everything to make things simpler / easier to upgrade
Expand All @@ -29,7 +38,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer):

_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'integer': types.Integer,
'bigint': types.BigInteger,
Expand Down Expand Up @@ -80,6 +89,7 @@ class PrestoDialect(default.DefaultDialect):
supports_multivalues_insert = True
supports_unicode_statements = True
supports_unicode_binds = True
supports_statement_cache = False
returns_unicode_strings = True
description_encoding = None
supports_native_boolean = True
Expand All @@ -88,6 +98,10 @@ class PrestoDialect(default.DefaultDialect):
@classmethod
def dbapi(cls):
return presto

@classmethod
def import_dbapi(cls):
return presto

def create_connect_args(self, url):
db_parts = (url.database or 'hive').split('/')
Expand All @@ -108,14 +122,14 @@ def create_connect_args(self, url):
return [], kwargs

def get_schema_names(self, connection, **kw):
return [row.Schema for row in connection.execute('SHOW SCHEMAS')]
return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))]

def _get_table_columns(self, connection, table_name, schema):
full_table = self.identifier_preparer.quote_identifier(table_name)
if schema:
full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
try:
return connection.execute('SHOW COLUMNS FROM {}'.format(full_table))
return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table)))
except (presto.DatabaseError, exc.DatabaseError) as e:
# Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
# it successfully does in the Hive version. The difference with Presto is that this
Expand All @@ -134,7 +148,7 @@ def _get_table_columns(self, connection, table_name, schema):
else:
raise

def has_table(self, connection, table_name, schema=None):
def has_table(self, connection, table_name, schema=None, **kw):
try:
self._get_table_columns(connection, table_name, schema)
return True
Expand Down Expand Up @@ -176,6 +190,8 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
# - a boolean column named "Partition Key"
# - a string in the "Comment" column
# - a string in the "Extra" column
if sqlalchemy_version >= 1.4:
row = row._mapping
is_partition_key = (
(part_key in row and row[part_key])
or row['Comment'].startswith(part_key)
Expand All @@ -192,7 +208,7 @@ def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
return [row.Table for row in connection.execute(query)]
return [row.Table for row in connection.execute(text(query))]

def do_rollback(self, dbapi_connection):
# No transactions for Presto
Expand Down
15 changes: 13 additions & 2 deletions pyhive/sqlalchemy_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
Expand All @@ -28,7 +34,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer):

_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'integer': types.Integer,
'bigint': types.BigInteger,
Expand Down Expand Up @@ -67,7 +73,12 @@ def visit_TEXT(self, type_, **kw):

class TrinoDialect(PrestoDialect):
name = 'trino'
supports_statement_cache = False

@classmethod
def dbapi(cls):
return trino

@classmethod
def import_dbapi(cls):
return trino
Loading

0 comments on commit 838727e

Please sign in to comment.