Skip to content

Commit

Permalink
minimal changes for sqlalchemy 2.0 support
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeshmu committed Jun 23, 2023
1 parent 0bd6f5b commit d1bf344
Show file tree
Hide file tree
Showing 13 changed files with 656 additions and 122 deletions.
16 changes: 16 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
engine = create_engine('trino://localhost:8080/hive/default')
# Hive
engine = create_engine('hive://localhost:10000/default')
# SQLAlchemy < 2.0
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
print select([func.count('*')], from_obj=logs).scalar()
Expand All @@ -82,6 +84,20 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
print select([func.count('*')], from_obj=logs).scalar()
# SQLAlchemy >= 2.0
metadata_obj = MetaData()
books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String))
metadata_obj.create_all(engine)
inspector = inspect(engine)
inspector.get_columns('books')
with engine.connect() as con:
data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" },
{ "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }]
con.execute(books.insert(), data[0])
result = con.execute(text("select * from books"))
print(result.fetchall())
Note: query generation functionality is not exhaustive or fully tested, but there should be no
problem with raw SQL.

Expand Down
2 changes: 2 additions & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pytest-timeout==1.2.0
requests>=1.0.0
requests_kerberos>=0.12.0
sasl>=0.2.1
pure-sasl>=0.6.2
kerberos>=1.3.0
thrift>=0.10.0
#thrift_sasl>=0.1.0
git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches
56 changes: 41 additions & 15 deletions pyhive/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,45 @@
}


def get_sasl_client(host, sasl_auth, service=None, username=None, password=None):
    """Build and initialise a client from the optional ``sasl`` package.

    The mechanism is validated before the optional dependency is imported, so
    an unsupported ``sasl_auth`` is always reported as ``ValueError`` (whether
    or not ``sasl`` is installed) and no client is created just to be thrown
    away.

    :param host: value stored in the client's ``host`` attribute.
    :param sasl_auth: SASL mechanism name; ``'GSSAPI'`` or ``'PLAIN'``.
    :param service: stored as the ``service`` attribute when using ``'GSSAPI'``.
    :param username: stored as the ``username`` attribute when using ``'PLAIN'``.
    :param password: stored as the ``password`` attribute when using ``'PLAIN'``.
    :returns: the ``sasl.Client`` instance after ``init()`` has been called.
    :raises ValueError: if ``sasl_auth`` is neither ``'GSSAPI'`` nor ``'PLAIN'``.
    :raises ImportError: if the ``sasl`` package is not installed.
    """
    if sasl_auth == 'GSSAPI':
        mechanism_attrs = [('service', service)]
    elif sasl_auth == 'PLAIN':
        mechanism_attrs = [('username', username), ('password', password)]
    else:
        raise ValueError("sasl_auth only supports GSSAPI and PLAIN")

    # Defer import so package dependency is optional
    import sasl

    sasl_client = sasl.Client()
    sasl_client.setAttr('host', host)
    for attr_name, attr_value in mechanism_attrs:
        sasl_client.setAttr(attr_name, attr_value)
    sasl_client.init()
    return sasl_client


def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None):
    """Build a ``PureSASLClient`` from ``pyhive.sasl_compat``.

    The mechanism is validated before the compatibility module is imported,
    so an unsupported ``sasl_auth`` is always reported as ``ValueError``
    rather than as an import failure.

    :param host: passed to ``PureSASLClient`` as ``host``.
    :param sasl_auth: SASL mechanism name; ``'GSSAPI'`` or ``'PLAIN'``.
    :param service: passed as ``service`` when using ``'GSSAPI'``.
    :param username: passed as ``username`` when using ``'PLAIN'``.
    :param password: passed as ``password`` when using ``'PLAIN'``.
    :returns: a new ``PureSASLClient``.
    :raises ValueError: if ``sasl_auth`` is neither ``'GSSAPI'`` nor ``'PLAIN'``.
    :raises ImportError: if ``pyhive.sasl_compat`` cannot be imported.
    """
    if sasl_auth == 'GSSAPI':
        sasl_kwargs = {'service': service}
    elif sasl_auth == 'PLAIN':
        sasl_kwargs = {'username': username, 'password': password}
    else:
        raise ValueError("sasl_auth only supports GSSAPI and PLAIN")

    # Defer import so package dependency is optional
    from pyhive.sasl_compat import PureSASLClient

    return PureSASLClient(host=host, **sasl_kwargs)


def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None):
    """Return a SASL client from whichever supported library is importable.

    ``get_sasl_client`` is tried first; if it raises ``ImportError`` the
    request is retried with ``get_pure_sasl_client`` using the same arguments.
    Any other exception from either factory propagates unchanged.

    :param host: forwarded to the selected factory.
    :param sasl_auth: SASL mechanism name, forwarded to the selected factory.
    :param service: forwarded to the selected factory.
    :param username: forwarded to the selected factory.
    :param password: forwarded to the selected factory.
    :returns: whatever the selected factory returns.
    """
    factory_kwargs = dict(
        host=host,
        sasl_auth=sasl_auth,
        service=service,
        username=username,
        password=password,
    )
    try:
        # Succeeds only when the sasl library is available.
        return get_sasl_client(**factory_kwargs)
    except ImportError:
        # Fallback to pure-sasl library
        return get_pure_sasl_client(**factory_kwargs)


def _parse_timestamp(value):
if value:
match = _TIMESTAMP_PATTERN.match(value)
Expand Down Expand Up @@ -200,7 +239,6 @@ def __init__(
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
# Defer import so package dependency is optional
import sasl
import thrift_sasl

if auth == 'KERBEROS':
Expand All @@ -211,20 +249,8 @@ def __init__(
if password is None:
# Password doesn't matter in NONE mode, just needs to be nonempty.
password = 'x'

def sasl_factory():
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)
if sasl_auth == 'GSSAPI':
sasl_client.setAttr('service', kerberos_service_name)
elif sasl_auth == 'PLAIN':
sasl_client.setAttr('username', username)
sasl_client.setAttr('password', password)
else:
raise AssertionError
sasl_client.init()
return sasl_client
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)

self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket)
else:
# All HS2 config options:
# https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration
Expand Down
56 changes: 56 additions & 0 deletions pyhive/sasl_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py
# which uses Apache-2.0 license as of 21 May 2023.
# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl
# via PR https://github.com/cloudera/impyla/pull/179
# Although thrift_sasl lists pure-sasl as a dependency (https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34),
# it still calls functions native to python-sasl (https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82).
# Hence this code is required for the fallback to work.


from puresasl.client import SASLClient, SASLError
from contextlib import contextmanager

@contextmanager
def error_catcher(self, Exc=Exception):
    """Context manager that records a failure on ``self`` instead of raising.

    On entry ``self.error`` is reset to ``None``. If the managed block raises
    an exception matching ``Exc``, it is suppressed and its ``str()`` is stored
    in ``self.error``; execution then continues after the ``with`` statement.
    Exceptions that do not match ``Exc`` propagate normally.

    :param self: any object on which an ``error`` attribute can be set.
    :param Exc: exception class (or tuple of classes) to intercept.
    """
    try:
        self.error = None
        yield
    except Exc as caught:
        self.error = str(caught)


class PureSASLClient(SASLClient):
    """pure-sasl ``SASLClient`` exposing python-sasl style methods.

    Each operation returns a tuple whose first item is a success flag instead
    of raising. After a failure the message is available from ``getError()``.
    """

    def __init__(self, *args, **kwargs):
        # No error has been recorded yet.
        self.error = None
        super(PureSASLClient, self).__init__(*args, **kwargs)

    def _guarded(self, method_name, payload):
        """Call ``self.<method_name>(payload)`` and return ``(ok, result)``.

        Any ``Exception`` raised by the lookup or the call is recorded in
        ``self.error`` by ``error_catcher`` and reported as ``(False, None)``.
        """
        with error_catcher(self):
            return True, getattr(self, method_name)(payload)
        # Reached only when error_catcher suppressed an exception.
        return False, None

    def start(self, mechanism):
        """Choose a mechanism and produce the initial response.

        :param mechanism: a mechanism name or a list of candidate names.
        :returns: ``(True, chosen_mechanism, initial_response)`` on success,
            or ``(False, mechanism, None)`` if a ``SASLError`` was raised.
        """
        candidates = mechanism if isinstance(mechanism, list) else [mechanism]
        with error_catcher(self, SASLError):
            self.choose_mechanism(candidates)
            return True, self.mechanism, self.process()
        # Reached only when error_catcher suppressed a SASLError.
        return False, mechanism, None

    def encode(self, incoming):
        """Return ``(True, self.unwrap(incoming))`` or ``(False, None)`` on error."""
        # NOTE(review): encode() delegates to unwrap() and decode() to wrap().
        # python-sasl's encode() is presumably the outgoing direction, so this
        # mapping looks inverted -- confirm it is intended before relying on it
        # with a security layer.
        return self._guarded('unwrap', incoming)

    def decode(self, outgoing):
        """Return ``(True, self.wrap(outgoing))`` or ``(False, None)`` on error."""
        return self._guarded('wrap', outgoing)

    def step(self, challenge=None):
        """Return ``(True, self.process(challenge))`` or ``(False, None)`` on error."""
        return self._guarded('process', challenge)

    def getError(self):
        """Return the message recorded by the most recent failed call, or ``None``."""
        return self.error
30 changes: 23 additions & 7 deletions pyhive/sqlalchemy_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@

import re
from sqlalchemy import exc
from sqlalchemy import processors
from sqlalchemy.sql import text
try:
from sqlalchemy import processors
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.engine import processors
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
Expand Down Expand Up @@ -121,7 +132,7 @@ def __init__(self, dialect):

_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'int': types.Integer,
'bigint': types.BigInteger,
Expand Down Expand Up @@ -247,10 +258,15 @@ class HiveDialect(default.DefaultDialect):
supports_multivalues_insert = True
type_compiler = HiveTypeCompiler
supports_sane_rowcount = False
supports_statement_cache = False

@classmethod
def dbapi(cls):
return hive

@classmethod
def import_dbapi(cls):
return hive

def create_connect_args(self, url):
kwargs = {
Expand All @@ -265,7 +281,7 @@ def create_connect_args(self, url):

def get_schema_names(self, connection, **kw):
# Equivalent to SHOW DATABASES
return [row[0] for row in connection.execute('SHOW SCHEMAS')]
return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))]

def get_view_names(self, connection, schema=None, **kw):
# Hive does not provide functionality to query tableType
Expand All @@ -280,7 +296,7 @@ def _get_table_columns(self, connection, table_name, schema):
# Using DESCRIBE works but is uglier.
try:
# This needs the table name to be unescaped (no backticks).
rows = connection.execute('DESCRIBE {}'.format(full_table)).fetchall()
rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall()
except exc.OperationalError as e:
# Does the table exist?
regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}'
Expand All @@ -296,7 +312,7 @@ def _get_table_columns(self, connection, table_name, schema):
raise exc.NoSuchTableError(full_table)
return rows

def has_table(self, connection, table_name, schema=None):
def has_table(self, connection, table_name, schema=None, **kw):
try:
self._get_table_columns(connection, table_name, schema)
return True
Expand Down Expand Up @@ -361,7 +377,7 @@ def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' IN ' + self.identifier_preparer.quote_identifier(schema)
return [row[0] for row in connection.execute(query)]
return [row[0] for row in connection.execute(text(query))]

def do_rollback(self, dbapi_connection):
# No transactions for Hive
Expand Down
28 changes: 22 additions & 6 deletions pyhive/sqlalchemy_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,27 @@
from __future__ import unicode_literals

import re
import sqlalchemy
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
from sqlalchemy.sql import text
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler

from pyhive import presto
from pyhive.common import UniversalSet

# Leading "major.minor" of the installed SQLAlchemy version, parsed into a float
# so it can be compared against numeric thresholds.
# NOTE(review): a float misorders two-digit minors ("1.10" parses to 1.1), and
# re.search returns None (AttributeError on .group) when the version string has
# no third dot-separated component -- confirm neither case can occur.
sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))

class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
# Just quote everything to make things simpler / easier to upgrade
Expand All @@ -29,7 +38,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer):

_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'integer': types.Integer,
'bigint': types.BigInteger,
Expand Down Expand Up @@ -80,6 +89,7 @@ class PrestoDialect(default.DefaultDialect):
supports_multivalues_insert = True
supports_unicode_statements = True
supports_unicode_binds = True
supports_statement_cache = False
returns_unicode_strings = True
description_encoding = None
supports_native_boolean = True
Expand All @@ -88,6 +98,10 @@ class PrestoDialect(default.DefaultDialect):
@classmethod
def dbapi(cls):
return presto

@classmethod
def import_dbapi(cls):
return presto

def create_connect_args(self, url):
db_parts = (url.database or 'hive').split('/')
Expand All @@ -108,14 +122,14 @@ def create_connect_args(self, url):
return [], kwargs

def get_schema_names(self, connection, **kw):
return [row.Schema for row in connection.execute('SHOW SCHEMAS')]
return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))]

def _get_table_columns(self, connection, table_name, schema):
full_table = self.identifier_preparer.quote_identifier(table_name)
if schema:
full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
try:
return connection.execute('SHOW COLUMNS FROM {}'.format(full_table))
return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table)))
except (presto.DatabaseError, exc.DatabaseError) as e:
# Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
# it successfully does in the Hive version. The difference with Presto is that this
Expand All @@ -134,7 +148,7 @@ def _get_table_columns(self, connection, table_name, schema):
else:
raise

def has_table(self, connection, table_name, schema=None):
def has_table(self, connection, table_name, schema=None, **kw):
try:
self._get_table_columns(connection, table_name, schema)
return True
Expand Down Expand Up @@ -176,6 +190,8 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
# - a boolean column named "Partition Key"
# - a string in the "Comment" column
# - a string in the "Extra" column
if sqlalchemy_version >= 1.4:
row = row._mapping
is_partition_key = (
(part_key in row and row[part_key])
or row['Comment'].startswith(part_key)
Expand All @@ -192,7 +208,7 @@ def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
return [row.Table for row in connection.execute(query)]
return [row.Table for row in connection.execute(text(query))]

def do_rollback(self, dbapi_connection):
# No transactions for Presto
Expand Down
14 changes: 12 additions & 2 deletions pyhive/sqlalchemy_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
try:
from sqlalchemy.databases import mysql
mysql_tinyinteger = mysql.MSTinyInteger
except ImportError:
# Required for SQLAlchemy>=2.0
from sqlalchemy.dialects import mysql
mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
Expand All @@ -28,7 +34,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer):

_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'integer': types.Integer,
'bigint': types.BigInteger,
Expand Down Expand Up @@ -71,3 +77,7 @@ class TrinoDialect(PrestoDialect):
@classmethod
def dbapi(cls):
return trino

@classmethod
def import_dbapi(cls):
return trino
Loading

0 comments on commit d1bf344

Please sign in to comment.