forked from dropbox/PyHive
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1. Inherit from presto 2. Add travis test script 3. Add test cases
- Loading branch information
Showing
12 changed files
with
365 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Integration between SQLAlchemy and Trino. | ||
Some code based on | ||
https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py | ||
which is released under the MIT license. | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import unicode_literals | ||
|
||
import re | ||
from sqlalchemy import exc | ||
from sqlalchemy import types | ||
from sqlalchemy import util | ||
# TODO shouldn't use mysql type | ||
from sqlalchemy.databases import mysql | ||
from sqlalchemy.engine import default | ||
from sqlalchemy.sql import compiler | ||
from sqlalchemy.sql.compiler import SQLCompiler | ||
|
||
from pyhive import trino | ||
from pyhive.common import UniversalSet | ||
from pyhive.sqlalchemy_presto import PrestoDialect, PrestoCompiler, PrestoIdentifierPreparer | ||
|
||
class TrinoIdentifierPreparer(PrestoIdentifierPreparer):
    """Identifier preparer for the Trino dialect.

    Adds nothing of its own: all behaviour is inherited unchanged from
    ``PrestoIdentifierPreparer``.
    """
|
||
|
||
# Maps Trino type names (lower-case strings) to the SQLAlchemy type classes
# used to represent them.
# NOTE(review): nothing in the visible part of this module reads this mapping;
# presumably it mirrors the Presto dialect's table -- confirm before relying on it.
_type_map = {
    'boolean': types.Boolean,
    # Uses the MySQL-specific tinyint type; see the TODO next to the mysql import.
    'tinyint': mysql.MSTinyInteger,
    'smallint': types.SmallInteger,
    'integer': types.Integer,
    'bigint': types.BigInteger,
    # Both floating-point names map to the same SQLAlchemy Float type.
    'real': types.Float,
    'double': types.Float,
    'varchar': types.String,
    'timestamp': types.TIMESTAMP,
    'date': types.DATE,
    'varbinary': types.VARBINARY,
}
|
||
|
||
class TrinoCompiler(PrestoCompiler):
    """Statement compiler for the Trino dialect.

    Adds nothing of its own: all behaviour is inherited unchanged from
    ``PrestoCompiler``.
    """
|
||
|
||
class TrinoTypeCompiler(compiler.GenericTypeCompiler):
    """Render SQLAlchemy column types as Trino type names.

    The ``visit_*`` methods below are type-compiler hooks, so this class
    derives from SQLAlchemy's ``GenericTypeCompiler``. It previously derived
    from ``PrestoCompiler`` (a statement compiler), which has a different
    constructor and never dispatches to these hooks for type rendering.
    """

    def visit_CLOB(self, type_, **kw):
        """Reject CLOB columns.

        :raises ValueError: always.
        """
        raise ValueError("Trino does not support the CLOB column type.")

    def visit_NCLOB(self, type_, **kw):
        """Reject NCLOB columns.

        :raises ValueError: always.
        """
        raise ValueError("Trino does not support the NCLOB column type.")

    def visit_DATETIME(self, type_, **kw):
        """Reject DATETIME columns.

        :raises ValueError: always.
        """
        raise ValueError("Trino does not support the DATETIME column type.")

    def visit_FLOAT(self, type_, **kw):
        """Render every FLOAT column as ``DOUBLE``."""
        return 'DOUBLE'

    def visit_TEXT(self, type_, **kw):
        """Render TEXT as ``VARCHAR``, carrying the length over when one is set.

        :param type_: the type being rendered; only its ``length`` is read.
        :returns: ``'VARCHAR(<length>)'`` if ``type_.length`` is truthy,
            otherwise ``'VARCHAR'``.
        """
        if type_.length:
            return 'VARCHAR({:d})'.format(type_.length)
        else:
            return 'VARCHAR'
|
||
|
||
class TrinoDialect(PrestoDialect):
    """SQLAlchemy dialect for Trino, layered on top of ``PrestoDialect``.

    Only the dialect name and the DB-API module are overridden here; everything
    else is inherited.
    """

    name = 'trino'

    @classmethod
    def dbapi(cls):
        """Return the DB-API module backing this dialect (``pyhive.trino``)."""
        return trino
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""Trino integration tests. | ||
These rely on having a Trino+Hadoop cluster set up. | ||
They also require tables created by make_test_tables.sh. | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import unicode_literals | ||
|
||
import contextlib | ||
import os | ||
import requests | ||
|
||
from pyhive import exc | ||
from pyhive import trino | ||
from pyhive.tests.dbapi_test_case import DBAPITestCase | ||
from pyhive.tests.dbapi_test_case import with_cursor | ||
from pyhive.tests.test_presto import TestPresto | ||
import mock | ||
import unittest | ||
import datetime | ||
|
||
# Host and port passed to trino.connect() by TestTrino.connect().
_HOST = 'localhost'
# NOTE(review): kept as a string; presumably this has to match the
# http-server.http.port of the test Trino server -- confirm against its config.
_PORT = '18080'
|
||
|
||
class TestTrino(TestPresto):
    """Run the inherited Presto test cases against Trino.

    Only the cases whose connection setup or expected values are spelled out
    below are overridden; everything else comes from ``TestPresto``.
    """

    __test__ = True

    def connect(self):
        """Open a connection to the test server, using the test id as source."""
        return trino.connect(host=_HOST, port=_PORT, source=self.id())

    def test_bad_protocol(self):
        """Creating a cursor with an unknown protocol raises ValueError."""
        with self.assertRaisesRegexp(ValueError, 'Protocol must be'):
            trino.connect('localhost', protocol='nonsense').cursor()

    def test_escape_args(self):
        """Dates and datetimes are escaped to typed SQL literals."""
        escaper = trino.TrinoParamEscaper()
        cases = [
            (datetime.date(2020, 4, 17),
             "date '2020-04-17'"),
            (datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),
             "timestamp '2020-04-17 12:00:00.123'"),
        ]
        for value, literal in cases:
            self.assertEqual(escaper.escape_args((value,)), (literal,))

    @with_cursor
    def test_description(self, cursor):
        """A simple query exposes its column description and a query id."""
        cursor.execute('SELECT 1 AS foobar FROM one_row')
        wanted = [('foobar', 'integer', None, None, None, None, True)]
        self.assertEqual(cursor.description, wanted)
        self.assertIsNotNone(cursor.last_query_id)

    @with_cursor
    def test_complex(self, cursor):
        """Column descriptions and values of the complex-types table."""
        cursor.execute('SELECT * FROM one_row_complex')
        # TODO Trino drops the union field

        # (column name, reported type); the remaining description fields are
        # the same for every column.
        name_and_type = [
            ('boolean', 'boolean'),
            ('tinyint', 'tinyint'),
            ('smallint', 'smallint'),
            ('int', 'integer'),
            ('bigint', 'bigint'),
            ('float', 'real'),
            ('double', 'double'),
            ('string', 'varchar'),
            ('timestamp', 'timestamp'),
            ('binary', 'varbinary'),
            ('array', 'array(integer)'),
            ('map', 'map(integer,integer)'),
            ('struct', 'row(a integer,b integer)'),
            # ('union', 'varchar'),
            ('decimal', 'decimal(10,1)'),
        ]
        common_tail = (None, None, None, None, True)
        self.assertEqual(
            cursor.description,
            [pair + common_tail for pair in name_and_type])

        rows = cursor.fetchall()
        expected_row = (
            True,
            127,
            32767,
            2147483647,
            9223372036854775807,
            0.5,
            0.25,
            'a string',
            '1970-01-01 00:00:00.000',
            b'123',
            [1, 2],
            {"1": 2, "3": 4},  # Trino converts all keys to strings so that they're valid JSON
            [1, 2],  # struct is returned as a list of elements
            # '{0:1}',
            '0.1',
        )
        expected = [expected_row]
        self.assertEqual(rows, expected)
        # catch unicode/str
        self.assertEqual(
            [type(value) for value in rows[0]],
            [type(value) for value in expected_row])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
"""DB-API implementation backed by Trino | ||
See http://www.python.org/dev/peps/pep-0249/ | ||
Many docstrings in this file are based on the PEP, which is in the public domain. | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import unicode_literals | ||
|
||
import logging | ||
|
||
import requests | ||
|
||
# Make all exceptions visible in this module per DB-API | ||
from pyhive.common import DBAPITypeObject | ||
from pyhive.exc import * # noqa | ||
from pyhive.presto import Connection as PrestoConnection, Cursor as PrestoCursor, PrestoParamEscaper | ||
|
||
try: # Python 3 | ||
import urllib.parse as urlparse | ||
except ImportError: # Python 2 | ||
import urlparse | ||
|
||
# PEP 249 module globals | ||
apilevel = '2.0' | ||
threadsafety = 2 # Threads may share the module and connections. | ||
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
class TrinoParamEscaper(PrestoParamEscaper):
    """Parameter escaper for Trino.

    Adds nothing of its own: all escaping rules are inherited unchanged from
    ``PrestoParamEscaper``.
    """
|
||
|
||
# Module-level escaper instance; Cursor.execute() uses it to escape query
# parameters before substituting them into the SQL text.
_escaper = TrinoParamEscaper()
|
||
|
||
def connect(*args, **kwargs):
    """Constructor for creating a connection to the database.

    All positional and keyword arguments are passed through unchanged; see
    class :py:class:`Connection` for the accepted arguments.

    :returns: a :py:class:`Connection` object.
    """
    return Connection(*args, **kwargs)
|
||
|
||
class Connection(PrestoConnection):
    """Trino connection: a ``PrestoConnection`` that hands out Trino cursors."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``PrestoConnection``."""
        # Two-argument super(): this module keeps Python 2 compatibility
        # (__future__ imports, urlparse fallback), and the zero-argument form
        # is Python-3-only syntax-level magic that raises TypeError on Python 2.
        super(Connection, self).__init__(*args, **kwargs)

    def cursor(self):
        """Return a new :py:class:`Cursor` object using the connection."""
        # NOTE(review): _args/_kwargs are presumably stored by
        # PrestoConnection.__init__ -- confirm against pyhive.presto.
        return Cursor(*self._args, **self._kwargs)
|
||
|
||
class Cursor(PrestoCursor):
    """Database cursor used to manage the context of a fetch operation.

    Cursors are not isolated: changes made to the database through one cursor
    are immediately visible to other cursors and connections.
    """

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        :param operation: SQL text, optionally containing pyformat placeholders.
        :param parameters: values for the placeholders; when ``None`` the
            operation is sent as-is without any substitution.

        Return values are not defined.
        """
        request_headers = {
            'X-Trino-Catalog': self._catalog,
            'X-Trino-Schema': self._schema,
            'X-Trino-Source': self._source,
            'X-Trino-User': self._username,
        }

        if self._session_props:
            # Session properties travel as one comma-separated name=value list.
            assignments = [
                '{}={}'.format(key, value)
                for key, value in self._session_props.items()
            ]
            request_headers['X-Trino-Session'] = ','.join(assignments)

        # Prepare statement (before resetting state, so an escaping failure
        # leaves the previous cursor state untouched).
        sql = operation if parameters is None else operation % _escaper.escape_args(parameters)

        self._reset_state()
        self._state = self._STATE_RUNNING

        netloc = '{}:{}'.format(self._host, self._port)
        url = urlparse.urlunparse(
            (self._protocol, netloc, '/v1/statement', None, None, None))
        _logger.info('%s', sql)
        _logger.debug("Headers: %s", request_headers)
        response = self._requests_session.post(
            url,
            data=sql.encode('utf-8'),
            headers=request_headers,
            **self._requests_kwargs)
        self._process_response(response)

    def _process_response(self, response):
        """Update the cursor state from one response of Trino's REST API.

        Records the next URI, the column metadata, the query id, any session
        property changes announced in the response headers, and any rows
        carried by the response.

        :raises OperationalError: if the HTTP status code is not OK.
        :raises DatabaseError: if the response body reports an error (raised
            only after the state above has been updated).
        """
        # TODO handle HTTP 503
        if response.status_code != requests.codes.ok:
            message = "Unexpected status code {}\n{}".format(
                response.status_code, response.content)
            raise OperationalError(message)

        payload = response.json()
        _logger.debug("Got response %s", payload)
        assert self._state == self._STATE_RUNNING, "Should be running if processing response"
        self._nextUri = payload.get('nextUri')
        self._columns = payload.get('columns')
        if 'id' in payload:
            self.last_query_id = payload['id']

        reply_headers = response.headers
        if 'X-Trino-Clear-Session' in reply_headers:
            cleared = reply_headers['X-Trino-Clear-Session']
            self._session_props.pop(cleared, None)
        if 'X-Trino-Set-Session' in reply_headers:
            # Split on the first '=' only, so the value may itself contain '='.
            key, value = reply_headers['X-Trino-Set-Session'].split('=', 1)
            self._session_props[key] = value

        if 'data' in payload:
            assert self._columns
            rows = payload['data']
            self._decode_binary(rows)
            self._data += (tuple(row) for row in rows)
        if 'nextUri' not in payload:
            # No follow-up URI means the query has no more pages.
            self._state = self._STATE_FINISHED
        if 'error' in payload:
            raise DatabaseError(payload['error'])
|
||
|
||
# | ||
# Type Objects and Constructors | ||
# | ||
|
||
|
||
# See types in trino-main/src/main/java/com/facebook/trino/tuple/TupleInfo.java | ||
# DB-API type objects, each built from the list of Trino type names it stands for.
# NOTE(review): presumably these compare equal to the type-name strings reported
# in cursor.description -- confirm against pyhive.common.DBAPITypeObject.
FIXED_INT_64 = DBAPITypeObject(['bigint'])
VARIABLE_BINARY = DBAPITypeObject(['varchar'])
DOUBLE = DBAPITypeObject(['double'])
BOOLEAN = DBAPITypeObject(['boolean'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
connector.name=hive-hadoop2 | ||
hive.metastore.uri=thrift://localhost:9083 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
coordinator=true | ||
node-scheduler.include-coordinator=true | ||
http-server.http.port=18080 | ||
query.max-memory=100MB | ||
query.max-memory-per-node=100MB | ||
discovery-server.enabled=true | ||
discovery.uri=http://localhost:18080 |
Empty file.
Oops, something went wrong.