From d6e7140d366cc37d01af0d8308e7739d090d73ef Mon Sep 17 00:00:00 2001 From: wgzhao Date: Thu, 18 Mar 2021 05:43:03 +0800 Subject: [PATCH] Add support for Trino (#381) 1. Inherit from presto 2. Add travis test script 3. Add test cases --- .travis.yml | 10 +- README.rst | 12 +- pyhive/sqlalchemy_trino.py | 73 +++++++++ pyhive/tests/test_trino.py | 96 ++++++++++++ pyhive/trino.py | 144 ++++++++++++++++++ .../travis-conf/trino/catalog/hive.properties | 2 + scripts/travis-conf/trino/config.properties | 7 + scripts/travis-conf/trino/jvm.config | 0 scripts/travis-conf/trino/node.properties | 3 + scripts/travis-install.sh | 21 +++ setup.cfg | 1 + setup.py | 2 + 12 files changed, 365 insertions(+), 6 deletions(-) create mode 100644 pyhive/sqlalchemy_trino.py create mode 100644 pyhive/tests/test_trino.py create mode 100644 pyhive/trino.py create mode 100644 scripts/travis-conf/trino/catalog/hive.properties create mode 100644 scripts/travis-conf/trino/config.properties create mode 100644 scripts/travis-conf/trino/jvm.config create mode 100644 scripts/travis-conf/trino/node.properties diff --git a/.travis.yml b/.travis.yml index 30814a5a..73b6c805 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,15 +6,15 @@ matrix: # https://docs.python.org/devguide/#status-of-python-branches # One build pulls latest versions dynamically - python: 3.6 - env: CDH=cdh5 CDH_VERSION=5 PRESTO=RELEASE SQLALCHEMY=sqlalchemy>=1.3.0 + env: CDH=cdh5 CDH_VERSION=5 PRESTO=RELEASE TRINO=RELEASE SQLALCHEMY=sqlalchemy>=1.3.0 - python: 3.6 - env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0 + env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0 - python: 3.5 - env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0 + env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0 - python: 3.4 - env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0 + env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0 - python: 2.7 - env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0 + env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0 install: - ./scripts/travis-install.sh - pip install codecov diff --git a/README.rst b/README.rst index 2ada488f..8903ce78 100644 --- a/README.rst +++ b/README.rst @@ -17,7 +17,7 @@ DB-API ------ .. code-block:: python - from pyhive import presto # or import hive + from pyhive import presto # or import hive or import trino cursor = presto.connect('localhost').cursor() cursor.execute('SELECT * FROM my_awesome_data LIMIT 10') print cursor.fetchone() @@ -63,6 +63,8 @@ First install this package to register it with SQLAlchemy (see ``setup.py``). from sqlalchemy.schema import * # Presto engine = create_engine('presto://localhost:8080/hive/default') + # Trino + engine = create_engine('trino://localhost:8080/hive/default') # Hive engine = create_engine('hive://localhost:10000/default') logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) @@ -79,12 +81,18 @@ Passing session configuration # DB-API hive.connect('localhost', configuration={'hive.exec.reducers.max': '123'}) presto.connect('localhost', session_props={'query_max_run_time': '1234m'}) + trino.connect('localhost', session_props={'query_max_run_time': '1234m'}) # SQLAlchemy create_engine( 'presto://user@host:443/hive', connect_args={'protocol': 'https', 'session_props': {'query_max_run_time': '1234m'}} ) + create_engine( + 'trino://user@host:443/hive', + connect_args={'protocol': 'https', + 'session_props': {'query_max_run_time': '1234m'}} + ) create_engine( 'hive://user@host:10000/database', connect_args={'configuration': {'hive.exec.reducers.max': '123'}}, @@ -102,11 +110,13 @@ Install using - ``pip install 'pyhive[hive]'`` for the Hive interface and - ``pip install 'pyhive[presto]'`` for the Presto interface. +- ``pip install 'pyhive[trino]'`` for the Trino interface PyHive works with - Python 2.7 / Python 3 - For Presto: Presto install +- For Trino: Trino install - For Hive: `HiveServer2 `_ daemon Changelog diff --git a/pyhive/sqlalchemy_trino.py b/pyhive/sqlalchemy_trino.py new file mode 100644 index 00000000..4b2b3698 --- /dev/null +++ b/pyhive/sqlalchemy_trino.py @@ -0,0 +1,73 @@ +"""Integration between SQLAlchemy and Trino. + +Some code based on +https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py +which is released under the MIT license. +""" + +from __future__ import absolute_import +from __future__ import unicode_literals + +import re +from sqlalchemy import exc +from sqlalchemy import types +from sqlalchemy import util +# TODO shouldn't use mysql type +from sqlalchemy.databases import mysql +from sqlalchemy.engine import default +from sqlalchemy.sql import compiler +from sqlalchemy.sql.compiler import SQLCompiler + +from pyhive import trino +from pyhive.common import UniversalSet +from pyhive.sqlalchemy_presto import PrestoDialect, PrestoCompiler, PrestoIdentifierPreparer + +class TrinoIdentifierPreparer(PrestoIdentifierPreparer): + pass + + +_type_map = { + 'boolean': types.Boolean, + 'tinyint': mysql.MSTinyInteger, + 'smallint': types.SmallInteger, + 'integer': types.Integer, + 'bigint': types.BigInteger, + 'real': types.Float, + 'double': types.Float, + 'varchar': types.String, + 'timestamp': types.TIMESTAMP, + 'date': types.DATE, + 'varbinary': types.VARBINARY, +} + + +class TrinoCompiler(PrestoCompiler): + pass + + +class TrinoTypeCompiler(PrestoCompiler): + def visit_CLOB(self, type_, **kw): + raise ValueError("Trino does not support the CLOB column type.") + + def visit_NCLOB(self, type_, **kw): + raise ValueError("Trino does not support the NCLOB column type.") + + def visit_DATETIME(self, type_, **kw): + raise ValueError("Trino does not support the DATETIME column type.") + + def visit_FLOAT(self, type_, **kw): + return 'DOUBLE' + + def visit_TEXT(self, type_, **kw): + if type_.length: + return 'VARCHAR({:d})'.format(type_.length) + else: + return 'VARCHAR' + + +class TrinoDialect(PrestoDialect): + name = 'trino' + + @classmethod + def dbapi(cls): + return trino diff --git a/pyhive/tests/test_trino.py b/pyhive/tests/test_trino.py new file mode 100644 index 00000000..cdc8bb43 --- /dev/null +++ b/pyhive/tests/test_trino.py @@ -0,0 +1,96 @@ +"""Trino integration tests. + +These rely on having a Trino+Hadoop cluster set up. +They also require a tables created by make_test_tables.sh. +""" + +from __future__ import absolute_import +from __future__ import unicode_literals + +import contextlib +import os +import requests + +from pyhive import exc +from pyhive import trino +from pyhive.tests.dbapi_test_case import DBAPITestCase +from pyhive.tests.dbapi_test_case import with_cursor +from pyhive.tests.test_presto import TestPresto +import mock +import unittest +import datetime + +_HOST = 'localhost' +_PORT = '18080' + + +class TestTrino(TestPresto): + __test__ = True + + def connect(self): + return trino.connect(host=_HOST, port=_PORT, source=self.id()) + + def test_bad_protocol(self): + self.assertRaisesRegexp(ValueError, 'Protocol must be', + lambda: trino.connect('localhost', protocol='nonsense').cursor()) + + def test_escape_args(self): + escaper = trino.TrinoParamEscaper() + + self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)), + ("date '2020-04-17'",)) + self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), + ("timestamp '2020-04-17 12:00:00.123'",)) + + @with_cursor + def test_description(self, cursor): + cursor.execute('SELECT 1 AS foobar FROM one_row') + self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)]) + self.assertIsNotNone(cursor.last_query_id) + + @with_cursor + def test_complex(self, cursor): + cursor.execute('SELECT * FROM one_row_complex') + # TODO Trino drops the union field + + tinyint_type = 'tinyint' + smallint_type = 'smallint' + float_type = 'real' + self.assertEqual(cursor.description, [ + ('boolean', 'boolean', None, None, None, None, True), + ('tinyint', tinyint_type, None, None, None, None, True), + ('smallint', smallint_type, None, None, None, None, True), + ('int', 'integer', None, None, None, None, True), + ('bigint', 'bigint', None, None, None, None, True), + ('float', float_type, None, None, None, None, True), + ('double', 'double', None, None, None, None, True), + ('string', 'varchar', None, None, None, None, True), + ('timestamp', 'timestamp', None, None, None, None, True), + ('binary', 'varbinary', None, None, None, None, True), + ('array', 'array(integer)', None, None, None, None, True), + ('map', 'map(integer,integer)', None, None, None, None, True), + ('struct', 'row(a integer,b integer)', None, None, None, None, True), + # ('union', 'varchar', None, None, None, None, True), + ('decimal', 'decimal(10,1)', None, None, None, None, True), + ]) + rows = cursor.fetchall() + expected = [( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + 'a string', + '1970-01-01 00:00:00.000', + b'123', + [1, 2], + {"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON + [1, 2], # struct is returned as a list of elements + # '{0:1}', + '0.1', + )] + self.assertEqual(rows, expected) + # catch unicode/str + self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) \ No newline at end of file diff --git a/pyhive/trino.py b/pyhive/trino.py new file mode 100644 index 00000000..e8a1aabd --- /dev/null +++ b/pyhive/trino.py @@ -0,0 +1,144 @@ +"""DB-API implementation backed by Trino + +See http://www.python.org/dev/peps/pep-0249/ + +Many docstrings in this file are based on the PEP, which is in the public domain. +""" + +from __future__ import absolute_import +from __future__ import unicode_literals + +import logging + +import requests + +# Make all exceptions visible in this module per DB-API +from pyhive.common import DBAPITypeObject +from pyhive.exc import * # noqa +from pyhive.presto import Connection as PrestoConnection, Cursor as PrestoCursor, PrestoParamEscaper + +try: # Python 3 + import urllib.parse as urlparse +except ImportError: # Python 2 + import urlparse + +# PEP 249 module globals +apilevel = '2.0' +threadsafety = 2 # Threads may share the module and connections. +paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s + +_logger = logging.getLogger(__name__) + + +class TrinoParamEscaper(PrestoParamEscaper): + pass + + +_escaper = TrinoParamEscaper() + + +def connect(*args, **kwargs): + """Constructor for creating a connection to the database. See class :py:class:`Connection` for + arguments. + + :returns: a :py:class:`Connection` object. + """ + return Connection(*args, **kwargs) + + +class Connection(PrestoConnection): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def cursor(self): + """Return a new :py:class:`Cursor` object using the connection.""" + return Cursor(*self._args, **self._kwargs) + + +class Cursor(PrestoCursor): + """These objects represent a database cursor, which is used to manage the context of a fetch + operation. + + Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately + visible by other cursors or connections. + """ + + def execute(self, operation, parameters=None): + """Prepare and execute a database operation (query or command). + + Return values are not defined. + """ + headers = { + 'X-Trino-Catalog': self._catalog, + 'X-Trino-Schema': self._schema, + 'X-Trino-Source': self._source, + 'X-Trino-User': self._username, + } + + if self._session_props: + headers['X-Trino-Session'] = ','.join( + '{}={}'.format(propname, propval) + for propname, propval in self._session_props.items() + ) + + # Prepare statement + if parameters is None: + sql = operation + else: + sql = operation % _escaper.escape_args(parameters) + + self._reset_state() + + self._state = self._STATE_RUNNING + url = urlparse.urlunparse(( + self._protocol, + '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None)) + _logger.info('%s', sql) + _logger.debug("Headers: %s", headers) + response = self._requests_session.post( + url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs) + self._process_response(response) + + def _process_response(self, response): + """Given the JSON response from Trino's REST API, update the internal state with the next + URI and any data from the response + """ + # TODO handle HTTP 503 + if response.status_code != requests.codes.ok: + fmt = "Unexpected status code {}\n{}" + raise OperationalError(fmt.format(response.status_code, response.content)) + + response_json = response.json() + _logger.debug("Got response %s", response_json) + assert self._state == self._STATE_RUNNING, "Should be running if processing response" + self._nextUri = response_json.get('nextUri') + self._columns = response_json.get('columns') + if 'id' in response_json: + self.last_query_id = response_json['id'] + if 'X-Trino-Clear-Session' in response.headers: + propname = response.headers['X-Trino-Clear-Session'] + self._session_props.pop(propname, None) + if 'X-Trino-Set-Session' in response.headers: + propname, propval = response.headers['X-Trino-Set-Session'].split('=', 1) + self._session_props[propname] = propval + if 'data' in response_json: + assert self._columns + new_data = response_json['data'] + self._decode_binary(new_data) + self._data += map(tuple, new_data) + if 'nextUri' not in response_json: + self._state = self._STATE_FINISHED + if 'error' in response_json: + raise DatabaseError(response_json['error']) + + +# +# Type Objects and Constructors +# + + +# See types in trino-main/src/main/java/com/facebook/trino/tuple/TupleInfo.java +FIXED_INT_64 = DBAPITypeObject(['bigint']) +VARIABLE_BINARY = DBAPITypeObject(['varchar']) +DOUBLE = DBAPITypeObject(['double']) +BOOLEAN = DBAPITypeObject(['boolean']) diff --git a/scripts/travis-conf/trino/catalog/hive.properties b/scripts/travis-conf/trino/catalog/hive.properties new file mode 100644 index 00000000..5129f3c3 --- /dev/null +++ b/scripts/travis-conf/trino/catalog/hive.properties @@ -0,0 +1,2 @@ +connector.name=hive-hadoop2 +hive.metastore.uri=thrift://localhost:9083 diff --git a/scripts/travis-conf/trino/config.properties b/scripts/travis-conf/trino/config.properties new file mode 100644 index 00000000..dff1a087 --- /dev/null +++ b/scripts/travis-conf/trino/config.properties @@ -0,0 +1,7 @@ +coordinator=true +node-scheduler.include-coordinator=true +http-server.http.port=18080 +query.max-memory=100MB +query.max-memory-per-node=100MB +discovery-server.enabled=true +discovery.uri=http://localhost:18080 diff --git a/scripts/travis-conf/trino/jvm.config b/scripts/travis-conf/trino/jvm.config new file mode 100644 index 00000000..e69de29b diff --git a/scripts/travis-conf/trino/node.properties b/scripts/travis-conf/trino/node.properties new file mode 100644 index 00000000..8c1b8422 --- /dev/null +++ b/scripts/travis-conf/trino/node.properties @@ -0,0 +1,3 @@ +node.environment=production +node.id=11111111-1111-1111-1111-111111111111 +node.data-dir=/tmp/trino/data diff --git a/scripts/travis-install.sh b/scripts/travis-install.sh index c6a1f041..5bca8d98 100755 --- a/scripts/travis-install.sh +++ b/scripts/travis-install.sh @@ -62,6 +62,23 @@ cp -r $(dirname $0)/travis-conf/presto presto-server/etc /usr/bin/python2.7 presto-server/bin/launcher.py start +# +# Trino +# + +sudo apt-get -q install -y python # Use python2 for trino server + +mvn -q org.apache.maven.plugins:maven-dependency-plugin:3.0.0:copy \ + -Dartifact=io.trino:trino-server:${TRINO}:tar.gz \ + -DoutputDirectory=. +tar -x -z -f trino-server-*.tar.gz +rm -rf trino-server +mv trino-server-*/ trino-server + +cp -r $(dirname $0)/travis-conf/trino trino-server/etc + +/usr/bin/python2.7 trino-server/bin/launcher.py start + # # Python # @@ -73,3 +90,7 @@ pip install -r dev_requirements.txt # Sleep so Presto has time to start up. # Otherwise we might get 'No nodes available to run query' or 'Presto server is still initializing' while ! grep -q 'SERVER STARTED' /tmp/presto/data/var/log/server.log; do sleep 1; done + +# Sleep so Trino has time to start up. +# Otherwise we might get 'No nodes available to run query' or 'Presto server is still initializing' +while ! grep -q 'SERVER STARTED' /tmp/trino/data/var/log/server.log; do sleep 1; done diff --git a/setup.cfg b/setup.cfg index 7165989f..2a8d245d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,7 @@ flake8-ignore = presto-server/** ALL pyhive/hive.py F405 pyhive/presto.py F405 + pyhive/trino.py F405 W503 filterwarnings = error diff --git a/setup.py b/setup.py index 9903d78e..df410dbc 100755 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ def run_tests(self): ], extras_require={ 'presto': ['requests>=1.0.0'], + 'trino': ['requests>=1.0.0'], 'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], 'sqlalchemy': ['sqlalchemy>=1.3.0'], 'kerberos': ['requests_kerberos>=0.12.0'], @@ -66,6 +67,7 @@ def run_tests(self): 'sqlalchemy.dialects': [ 'hive = pyhive.sqlalchemy_hive:HiveDialect', 'presto = pyhive.sqlalchemy_presto:PrestoDialect', + 'trino = pyhive.sqlalchemy_trino:TrinoDialect', ], } )