Skip to content

Commit

Permalink
add compatibility with sqlalchemy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeshmu committed Jul 8, 2023
1 parent 838727e commit 7d533b7
Showing 1 changed file with 36 additions and 50 deletions.
86 changes: 36 additions & 50 deletions pyhive/tests/sqlalchemy_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ def wrapped_fn(self, *args, **kwargs):
engine.dispose()
return wrapped_fn

def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks):
if sqlalchemy_version >= 1.4:
insp = sqlalchemy.inspect(engine)
insp.reflect_table(
table,
include_columns=include_columns,
exclude_columns=exclude_columns,
resolve_fks=resolve_fks,
)

else:
engine.dialect.reflecttable(
connection, table, include_columns=include_columns,
exclude_columns=exclude_columns, resolve_fks=resolve_fks)


class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)):
@with_engine_connection
Expand Down Expand Up @@ -66,19 +81,8 @@ def test_reflect_include_columns(self, engine, connection):
"""When passed include_columns, reflecttable should filter out other columns"""

one_row_complex = Table('one_row_complex', MetaData())

if sqlalchemy_version == 1.3:
engine.dialect.reflecttable(
connection, one_row_complex, include_columns=['int'],
exclude_columns=[], resolve_fks=True)
else:
insp = sqlalchemy.inspect(engine)
insp.reflect_table(
one_row_complex,
include_columns=['int'],
exclude_columns=[],
resolve_fks=True,
)
reflect_table(engine, connection, one_row_complex, include_columns=['int'],
exclude_columns=[], resolve_fks=True)

self.assertEqual(len(one_row_complex.c), 1)
self.assertIsNotNone(one_row_complex.c.int)
Expand All @@ -99,38 +103,16 @@ def test_reflect_partitions(self, engine, connection):
self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)}))

many_rows = Table('many_rows', MetaData())

if sqlalchemy_version == 1.3:
engine.dialect.reflecttable(
connection, many_rows, include_columns=['a'],
exclude_columns=[], resolve_fks=True)
else:
insp = sqlalchemy.inspect(engine)
insp.reflect_table(
many_rows,
include_columns=['a'],
exclude_columns=[],
resolve_fks=True,
)
reflect_table(engine, connection, many_rows, include_columns=['a'],
exclude_columns=[], resolve_fks=True)

self.assertEqual(len(many_rows.c), 1)
self.assertFalse(many_rows.c.a.index)
self.assertFalse(many_rows.indexes)

many_rows = Table('many_rows', MetaData())

if sqlalchemy_version == 1.3:
engine.dialect.reflecttable(
connection, many_rows, include_columns=['b'],
exclude_columns=[], resolve_fks=True)
else:
insp = sqlalchemy.inspect(engine)
insp.reflect_table(
many_rows,
include_columns=['b'],
exclude_columns=[],
resolve_fks=True,
)
reflect_table(engine, connection, many_rows, include_columns=['b'],
exclude_columns=[], resolve_fks=True)

self.assertEqual(len(many_rows.c), 1)
self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)}))
Expand All @@ -140,12 +122,14 @@ def test_unicode(self, engine, connection):
"""Verify that unicode strings make it through SQLAlchemy and the backend"""
unicode_str = "中文"
one_row = Table('one_row', MetaData())
if sqlalchemy_version == 1.3:
returned_str = connection.execute(sqlalchemy.select([
expression.bindparam("好", unicode_str, type_=String())]).select_from(one_row)).scalar()
else:

if sqlalchemy_version >= 1.4:
returned_str = connection.execute(sqlalchemy.select(
expression.bindparam("好", unicode_str, type_=String())).select_from(one_row)).scalar()
else:
returned_str = connection.execute(sqlalchemy.select([
expression.bindparam("好", unicode_str, type_=String())]).select_from(one_row)).scalar()

self.assertEqual(returned_str, unicode_str)

@with_engine_connection
Expand All @@ -170,19 +154,21 @@ def test_get_table_names(self, engine, connection):

@with_engine_connection
def test_has_table(self, engine, connection):
if sqlalchemy_version == 1.3:
self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists())
else:
if sqlalchemy_version >= 1.4:
insp = sqlalchemy.inspect(engine)
self.assertTrue(insp.has_table("one_row"))
self.assertFalse(insp.has_table("this_table_does_not_exist"))
else:
self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists())

@with_engine_connection
def test_char_length(self, engine, connection):
one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
if sqlalchemy_version == 1.3:
result = connection.execute(sqlalchemy.select([sqlalchemy.func.char_length(one_row_complex.c.string)])).scalar()
else:

if sqlalchemy_version >= 1.4:
result = connection.execute(sqlalchemy.select(sqlalchemy.func.char_length(one_row_complex.c.string))).scalar()
else:
result = connection.execute(sqlalchemy.select([sqlalchemy.func.char_length(one_row_complex.c.string)])).scalar()

self.assertEqual(result, len('a string'))

0 comments on commit 7d533b7

Please sign in to comment.