Skip to content

Commit

Permalink
Add support for Trino (dropbox#381)
Browse files Browse the repository at this point in the history
1. Inherit from presto
2. Add travis test script
3. Add test cases
  • Loading branch information
wgzhao authored Mar 17, 2021
1 parent 1548ecc commit d6e7140
Show file tree
Hide file tree
Showing 12 changed files with 365 additions and 6 deletions.
10 changes: 5 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ matrix:
# https://docs.python.org/devguide/#status-of-python-branches
# One build pulls latest versions dynamically
- python: 3.6
env: CDH=cdh5 CDH_VERSION=5 PRESTO=RELEASE SQLALCHEMY=sqlalchemy>=1.3.0
env: CDH=cdh5 CDH_VERSION=5 PRESTO=RELEASE TRINO=RELEASE SQLALCHEMY=sqlalchemy>=1.3.0
- python: 3.6
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0
- python: 3.5
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0
- python: 3.4
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0
- python: 2.7
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0
env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0
install:
- ./scripts/travis-install.sh
- pip install codecov
Expand Down
12 changes: 11 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ DB-API
------
.. code-block:: python
from pyhive import presto # or import hive
from pyhive import presto # or import hive or import trino
cursor = presto.connect('localhost').cursor()
cursor.execute('SELECT * FROM my_awesome_data LIMIT 10')
print cursor.fetchone()
Expand Down Expand Up @@ -63,6 +63,8 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
from sqlalchemy.schema import *
# Presto
engine = create_engine('presto://localhost:8080/hive/default')
# Trino
engine = create_engine('trino://localhost:8080/hive/default')
# Hive
engine = create_engine('hive://localhost:10000/default')
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
Expand All @@ -79,12 +81,18 @@ Passing session configuration
# DB-API
hive.connect('localhost', configuration={'hive.exec.reducers.max': '123'})
presto.connect('localhost', session_props={'query_max_run_time': '1234m'})
trino.connect('localhost', session_props={'query_max_run_time': '1234m'})
# SQLAlchemy
create_engine(
'presto://user@host:443/hive',
connect_args={'protocol': 'https',
'session_props': {'query_max_run_time': '1234m'}}
)
create_engine(
'trino://user@host:443/hive',
connect_args={'protocol': 'https',
'session_props': {'query_max_run_time': '1234m'}}
)
create_engine(
'hive://user@host:10000/database',
connect_args={'configuration': {'hive.exec.reducers.max': '123'}},
Expand All @@ -102,11 +110,13 @@ Install using

- ``pip install 'pyhive[hive]'`` for the Hive interface,
- ``pip install 'pyhive[presto]'`` for the Presto interface, and
- ``pip install 'pyhive[trino]'`` for the Trino interface.

PyHive works with

- Python 2.7 / Python 3
- For Presto: Presto install
- For Trino: Trino install
- For Hive: `HiveServer2 <https://cwiki.apache.org/confluence/display/Hive/Setting+up+HiveServer2>`_ daemon

Changelog
Expand Down
73 changes: 73 additions & 0 deletions pyhive/sqlalchemy_trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Integration between SQLAlchemy and Trino.
Some code based on
https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
which is released under the MIT license.
"""

from __future__ import absolute_import
from __future__ import unicode_literals

import re
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler

from pyhive import trino
from pyhive.common import UniversalSet
from pyhive.sqlalchemy_presto import PrestoDialect, PrestoCompiler, PrestoIdentifierPreparer

class TrinoIdentifierPreparer(PrestoIdentifierPreparer):
    """Identifier preparer for Trino; adds nothing on top of ``PrestoIdentifierPreparer``."""


# Lookup table from lower-case column type names to SQLAlchemy type classes.
# NOTE(review): nothing in the visible part of this module reads this table;
# confirm whether it is still needed or whether the inherited Presto dialect
# uses its own copy.
_type_map = {
    'boolean': types.Boolean,
    # Borrowed from the MySQL dialect (see the TODO on the mysql import).
    'tinyint': mysql.MSTinyInteger,
    'smallint': types.SmallInteger,
    'integer': types.Integer,
    'bigint': types.BigInteger,
    # Both floating-point names map to the same generic Float type.
    'real': types.Float,
    'double': types.Float,
    'varchar': types.String,
    'timestamp': types.TIMESTAMP,
    'date': types.DATE,
    'varbinary': types.VARBINARY,
}


class TrinoCompiler(PrestoCompiler):
    """Statement compiler for Trino; adds nothing on top of ``PrestoCompiler``."""


class TrinoTypeCompiler(compiler.GenericTypeCompiler):
    """Render SQLAlchemy column types as Trino type names.

    The ``visit_<TYPE>`` methods below are type-compiler hooks, so the class
    derives from SQLAlchemy's ``GenericTypeCompiler``. It previously derived
    from ``PrestoCompiler``, which appears to be the statement compiler; with
    that base the class could not be used as a dialect's type compiler.

    NOTE(review): if ``pyhive.sqlalchemy_presto`` exposes a Presto type
    compiler, deriving from it instead may be preferable -- confirm.
    """

    def visit_CLOB(self, type_, **kw):
        """Reject CLOB columns.

        :raises ValueError: always.
        """
        raise ValueError("Trino does not support the CLOB column type.")

    def visit_NCLOB(self, type_, **kw):
        """Reject NCLOB columns.

        :raises ValueError: always.
        """
        raise ValueError("Trino does not support the NCLOB column type.")

    def visit_DATETIME(self, type_, **kw):
        """Reject DATETIME columns.

        :raises ValueError: always.
        """
        raise ValueError("Trino does not support the DATETIME column type.")

    def visit_FLOAT(self, type_, **kw):
        """Render every FLOAT column as ``DOUBLE``."""
        return 'DOUBLE'

    def visit_TEXT(self, type_, **kw):
        """Render TEXT as ``VARCHAR``, with the length when one is set.

        :param type_: the column type; only its ``length`` attribute is read.
        :returns: ``'VARCHAR(<length>)'`` for a truthy length, else ``'VARCHAR'``.
        """
        if type_.length:
            return 'VARCHAR({:d})'.format(type_.length)
        else:
            return 'VARCHAR'


class TrinoDialect(PrestoDialect):
    """SQLAlchemy dialect named ``trino``; everything else comes from ``PrestoDialect``."""

    name = 'trino'

    @classmethod
    def dbapi(cls):
        """Return the DB-API module backing this dialect (``pyhive.trino``)."""
        return trino
96 changes: 96 additions & 0 deletions pyhive/tests/test_trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Trino integration tests.
These rely on having a Trino+Hadoop cluster set up.
They also require the tables created by make_test_tables.sh.
"""

from __future__ import absolute_import
from __future__ import unicode_literals

import contextlib
import os
import requests

from pyhive import exc
from pyhive import trino
from pyhive.tests.dbapi_test_case import DBAPITestCase
from pyhive.tests.dbapi_test_case import with_cursor
from pyhive.tests.test_presto import TestPresto
import mock
import unittest
import datetime

# Address of the Trino server the tests connect to. The port matches
# http-server.http.port in scripts/travis-conf/trino/config.properties.
_HOST = 'localhost'
_PORT = '18080'


class TestTrino(TestPresto):
    """Run the inherited Presto test cases against Trino, overriding the Trino-specific ones."""

    __test__ = True

    def connect(self):
        # Tag the connection with the running test's id via the ``source`` argument.
        return trino.connect(host=_HOST, port=_PORT, source=self.id())

    def test_bad_protocol(self):
        # An unknown protocol must be rejected by the time a cursor is requested.
        def open_cursor():
            return trino.connect('localhost', protocol='nonsense').cursor()

        self.assertRaisesRegexp(ValueError, 'Protocol must be', open_cursor)

    def test_escape_args(self):
        escaper = trino.TrinoParamEscaper()

        # (python value, expected SQL literal)
        cases = [
            (datetime.date(2020, 4, 17), "date '2020-04-17'"),
            (datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),
             "timestamp '2020-04-17 12:00:00.123'"),
        ]
        for value, literal in cases:
            self.assertEqual(escaper.escape_args((value,)), (literal,))

    @with_cursor
    def test_description(self, cursor):
        cursor.execute('SELECT 1 AS foobar FROM one_row')
        expected_description = [('foobar', 'integer', None, None, None, None, True)]
        self.assertEqual(cursor.description, expected_description)
        self.assertIsNotNone(cursor.last_query_id)

    @with_cursor
    def test_complex(self, cursor):
        cursor.execute('SELECT * FROM one_row_complex')
        # TODO Trino drops the union field

        # (column name, reported type code); the other description slots are constant.
        column_types = [
            ('boolean', 'boolean'),
            ('tinyint', 'tinyint'),
            ('smallint', 'smallint'),
            ('int', 'integer'),
            ('bigint', 'bigint'),
            ('float', 'real'),
            ('double', 'double'),
            ('string', 'varchar'),
            ('timestamp', 'timestamp'),
            ('binary', 'varbinary'),
            ('array', 'array(integer)'),
            ('map', 'map(integer,integer)'),
            ('struct', 'row(a integer,b integer)'),
            # ('union', 'varchar'),
            ('decimal', 'decimal(10,1)'),
        ]
        expected_description = [
            (name, type_code, None, None, None, None, True)
            for name, type_code in column_types
        ]
        self.assertEqual(cursor.description, expected_description)

        expected_row = (
            True,
            127,
            32767,
            2147483647,
            9223372036854775807,
            0.5,
            0.25,
            'a string',
            '1970-01-01 00:00:00.000',
            b'123',
            [1, 2],
            {"1": 2, "3": 4},  # Trino converts all keys to strings so that they're valid JSON
            [1, 2],  # struct is returned as a list of elements
            # '{0:1}',
            '0.1',
        )
        expected = [expected_row]
        rows = cursor.fetchall()
        self.assertEqual(rows, expected)
        # catch unicode/str
        actual_types = [type(value) for value in rows[0]]
        expected_types = [type(value) for value in expected[0]]
        self.assertEqual(actual_types, expected_types)
144 changes: 144 additions & 0 deletions pyhive/trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""DB-API implementation backed by Trino
See http://www.python.org/dev/peps/pep-0249/
Many docstrings in this file are based on the PEP, which is in the public domain.
"""

from __future__ import absolute_import
from __future__ import unicode_literals

import logging

import requests

# Make all exceptions visible in this module per DB-API
from pyhive.common import DBAPITypeObject
from pyhive.exc import * # noqa
from pyhive.presto import Connection as PrestoConnection, Cursor as PrestoCursor, PrestoParamEscaper

try: # Python 3
import urllib.parse as urlparse
except ImportError: # Python 2
import urlparse

# PEP 249 module globals
apilevel = '2.0'
threadsafety = 2  # Threads may share the module and connections.
paramstyle = 'pyformat'  # Python extended format codes, e.g. ...WHERE name=%(name)s

# Module-level logger; Cursor.execute logs the SQL at INFO and the request headers at DEBUG.
_logger = logging.getLogger(__name__)


class TrinoParamEscaper(PrestoParamEscaper):
    """Parameter escaper for Trino; adds nothing on top of ``PrestoParamEscaper``."""


# Shared escaper instance; Cursor.execute uses it to escape query parameters
# before interpolating them into the SQL text.
_escaper = TrinoParamEscaper()


def connect(*args, **kwargs):
    """Open a connection to the database.

    Every positional and keyword argument is forwarded untouched to
    :py:class:`Connection`; see that class for the accepted arguments.

    :returns: a :py:class:`Connection` object.
    """
    connection = Connection(*args, **kwargs)
    return connection


class Connection(PrestoConnection):
    """Trino connection: a ``PrestoConnection`` whose cursors are Trino cursors.

    The constructor is inherited unchanged from ``PrestoConnection``. An
    earlier pass-through ``__init__`` called zero-argument ``super()``, which
    exists only on Python 3 and raised ``TypeError`` on Python 2. This module
    still supports Python 2 (see the ``urlparse`` import fallback), and the
    override added nothing, so it is omitted.
    """

    def cursor(self):
        """Return a new :py:class:`Cursor` object using the connection.

        The cursor is built from ``self._args`` and ``self._kwargs``.
        """
        # presumably _args/_kwargs are stored by PrestoConnection.__init__; verify there.
        return Cursor(*self._args, **self._kwargs)


class Cursor(PrestoCursor):
    """Database cursor for Trino, used to manage the context of a fetch operation.

    It overrides the request/response handling of ``PrestoCursor`` to use
    the ``X-Trino-*`` HTTP headers.

    Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
    visible by other cursors or connections.
    """

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        :param operation: SQL text; when ``parameters`` is given it is used as
            a ``%`` format string.
        :param parameters: optional values, escaped before interpolation.

        Return values are not defined.
        """
        request_headers = {
            'X-Trino-Catalog': self._catalog,
            'X-Trino-Schema': self._schema,
            'X-Trino-Source': self._source,
            'X-Trino-User': self._username,
        }

        # Session properties travel as one comma-separated "name=value" header.
        if self._session_props:
            pairs = [
                '{}={}'.format(name, value)
                for name, value in self._session_props.items()
            ]
            request_headers['X-Trino-Session'] = ','.join(pairs)

        # Prepare statement
        sql = operation if parameters is None else operation % _escaper.escape_args(parameters)

        self._reset_state()
        self._state = self._STATE_RUNNING

        netloc = '{}:{}'.format(self._host, self._port)
        url = urlparse.urlunparse((self._protocol, netloc, '/v1/statement', None, None, None))
        _logger.info('%s', sql)
        _logger.debug("Headers: %s", request_headers)
        response = self._requests_session.post(
            url, data=sql.encode('utf-8'), headers=request_headers, **self._requests_kwargs)
        self._process_response(response)

    def _process_response(self, response):
        """Given the JSON response from Trino's REST API, update the internal state with the next
        URI and any data from the response

        :raises OperationalError: when the HTTP status is not OK.
        :raises DatabaseError: when the response body carries an ``error`` entry.
        """
        # TODO handle HTTP 503
        if response.status_code != requests.codes.ok:
            raise OperationalError(
                "Unexpected status code {}\n{}".format(response.status_code, response.content))

        payload = response.json()
        _logger.debug("Got response %s", payload)
        assert self._state == self._STATE_RUNNING, "Should be running if processing response"
        self._nextUri = payload.get('nextUri')
        self._columns = payload.get('columns')
        if 'id' in payload:
            self.last_query_id = payload['id']

        # The server can ask the client to drop or set a session property.
        http_headers = response.headers
        if 'X-Trino-Clear-Session' in http_headers:
            self._session_props.pop(http_headers['X-Trino-Clear-Session'], None)
        if 'X-Trino-Set-Session' in http_headers:
            name, value = http_headers['X-Trino-Set-Session'].split('=', 1)
            self._session_props[name] = value

        if 'data' in payload:
            assert self._columns
            rows = payload['data']
            self._decode_binary(rows)
            self._data += map(tuple, rows)
        # No nextUri key means there is nothing further to fetch.
        if 'nextUri' not in payload:
            self._state = self._STATE_FINISHED
        # Raised last so the state updates above are applied even on error.
        if 'error' in payload:
            raise DatabaseError(payload['error'])


#
# Type Objects and Constructors
#


# See types in trino-main/src/main/java/com/facebook/trino/tuple/TupleInfo.java
# NOTE(review): the path above looks like it was carried over from the Presto
# module with the project name substituted -- confirm it exists in Trino.
# Each constant wraps the list of type-name strings it stands for.
FIXED_INT_64 = DBAPITypeObject(['bigint'])
VARIABLE_BINARY = DBAPITypeObject(['varchar'])
DOUBLE = DBAPITypeObject(['double'])
BOOLEAN = DBAPITypeObject(['boolean'])
2 changes: 2 additions & 0 deletions scripts/travis-conf/trino/catalog/hive.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
connector.name=hive-hadoop2
hive.metastore.uri=thrift://localhost:9083
7 changes: 7 additions & 0 deletions scripts/travis-conf/trino/config.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
coordinator=true
node-scheduler.include-coordinator=true
http-server.http.port=18080
query.max-memory=100MB
query.max-memory-per-node=100MB
discovery-server.enabled=true
discovery.uri=http://localhost:18080
Empty file.
Loading

0 comments on commit d6e7140

Please sign in to comment.