diff --git a/.travis.yml b/.travis.yml index 30814a5a..fd3de6b3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,6 +15,7 @@ matrix: env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0 - python: 2.7 env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 SQLALCHEMY=sqlalchemy>=1.3.0 + install: - ./scripts/travis-install.sh - pip install codecov diff --git a/pyhive/hive.py b/pyhive/hive.py index a8635bac..8ceeaa9f 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -8,6 +8,7 @@ from __future__ import absolute_import from __future__ import unicode_literals +import base64 import datetime import re from decimal import Decimal @@ -28,6 +29,7 @@ import thrift.protocol.TBinaryProtocol import thrift.transport.TSocket import thrift.transport.TTransport +import thrift.transport.THttpClient # PEP 249 module globals apilevel = '2.0' @@ -94,12 +96,24 @@ def connect(*args, **kwargs): return Connection(*args, **kwargs) +# TODO +# Setting the Cookie in the headers should be implemented in the thrift library. +# We'll keep this here until that change is available in there. +class TCookieHttpClient(thrift.transport.THttpClient.THttpClient): + def flush(self): + super(TCookieHttpClient, self).flush() + + if 'Set-Cookie' in self.headers: + self.setCustomHeaders( + {'Cookie': self.headers['Set-Cookie']}) + + class Connection(object): """Wraps a Thrift session""" def __init__(self, host=None, port=None, username=None, database='default', auth=None, configuration=None, kerberos_service_name=None, password=None, - thrift_transport=None): + thrift_transport=None, thrift_transport_protocol='binary', http_path=None): """Connect to HiveServer2 :param host: What host HiveServer2 runs on @@ -119,9 +133,6 @@ def __init__(self, host=None, port=None, username=None, database='default', auth username = username or getpass.getuser() configuration = configuration or {} - if (password is not None) != (auth in ('LDAP', 'CUSTOM')): - raise ValueError("Password should be set if and only if in LDAP or CUSTOM mode; " - "Remove password or use one of those modes") if (kerberos_service_name is not None) != (auth == 'KERBEROS'): raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode") if thrift_transport is not None: @@ -138,51 +149,29 @@ def __init__(self, host=None, port=None, username=None, database='default', auth if thrift_transport is not None: self._transport = thrift_transport + elif thrift_transport_protocol == 'binary': + self._transport = Connection. \ + create_binary_transport(host=host, + port=port, + username=username, + password=password, + kerberos_service_name=kerberos_service_name, + auth=auth) + elif thrift_transport_protocol == 'http': + self._transport = Connection.\ + create_http_transport(host=host, + username=username, + port=port, + http_path=http_path, + password=password, + kerberos_service_name=kerberos_service_name, + auth=auth) else: - if port is None: - port = 10000 - if auth is None: - auth = 'NONE' - socket = thrift.transport.TSocket.TSocket(host, port) - if auth == 'NOSASL': - # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml - self._transport = thrift.transport.TTransport.TBufferedTransport(socket) - elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): - # Defer import so package dependency is optional - import sasl - import thrift_sasl - - if auth == 'KERBEROS': - # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library - sasl_auth = 'GSSAPI' - else: - sasl_auth = 'PLAIN' - if password is None: - # Password doesn't matter in NONE mode, just needs to be nonempty. - password = 'x' - - def sasl_factory(): - sasl_client = sasl.Client() - sasl_client.setAttr('host', host) - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', kerberos_service_name) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) - else: - raise AssertionError - sasl_client.init() - return sasl_client - self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) - else: - # All HS2 config options: - # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration - # PAM currently left to end user via thrift_transport option. - raise NotImplementedError( - "Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM " - "authentication are supported, got {}".format(auth)) - - protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport) + raise ValueError("Invalid thrift_transport_protocol: {}".format( + thrift_transport_protocol)) + + protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol( + self._transport) self._client = TCLIService.Client(protocol) # oldest version that still contains features we care about # "V6 uses binary type for binary payload (was string) and uses columnar result set" @@ -207,6 +196,109 @@ def sasl_factory(): self._transport.close() raise + @staticmethod + def create_http_transport(host, port, http_path, username, password, + kerberos_service_name, auth): + if port is None: + port = 10001 + if auth is None: + auth = 'NONE' + if http_path is None: + http_path = '/' + + socket = TCookieHttpClient('http://{}:{}{}'.format(host, port, http_path)) + + if auth == 'KERBEROS': + import kerberos + + __, krb_context = kerberos.authGSSClientInit( + service='{}@{}'.format(kerberos_service_name, host)) + kerberos.authGSSClientClean(krb_context, '') + kerberos.authGSSClientStep(krb_context, '') + auth_header = kerberos.authGSSClientResponse(krb_context) + + socket.setCustomHeaders( + {'Authorization': 'Negotiate {}'.format(auth_header)}) + + elif auth in ('BASIC', 'NOSASL', 'NONE'): + if auth == 'BASIC' and password is None: + raise ValueError("BASIC authentication requires password.") + + auth_credentials = '{}:{}'.format(username, password)\ + .encode('UTF-8') + auth_credentials_base64 = base64.standard_b64encode( + auth_credentials).decode('UTF-8') + + # we're using the Authorization header for auth NONE or NOSASL because that's where + # Hive gets the username when doAs is enabled + socket.setCustomHeaders( + {'Authorization': 'Basic {}'.format(auth_credentials_base64)}) + + else: + raise NotImplementedError( + "Only NONE, NOSASL, BASIC and KERBEROS authentication are supported " + "when using HTTP transport, got {}".format(auth)) + + return thrift.transport.TTransport.TBufferedTransport(socket) + + @staticmethod + def create_binary_transport(host, port, username, password, kerberos_service_name, auth): + + if port is None: + port = 10000 + if auth is None: + auth = 'NONE' + + if (password is not None) != (auth in ('LDAP', 'CUSTOM')): + raise ValueError( + "Password should be set if and only if in LDAP or CUSTOM mode; " + "Remove password or use one of those modes") + + socket = thrift.transport.TSocket.TSocket(host, port) + + if auth == 'NOSASL': + # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml + transport = thrift.transport.TTransport.TBufferedTransport(socket) + elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): + # Defer import so package dependency is optional + import sasl + import thrift_sasl + + if auth == 'KERBEROS': + # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library + sasl_auth = 'GSSAPI' + else: + sasl_auth = 'PLAIN' + if password is None: + # Password doesn't matter in NONE mode, just needs to be nonempty. + password = 'x' + + def sasl_factory(): + sasl_client = sasl.Client() + sasl_client.setAttr('host', host) + if sasl_auth == 'GSSAPI': + sasl_client.setAttr('service', kerberos_service_name) + elif sasl_auth == 'PLAIN': + sasl_client.setAttr('username', username) + sasl_client.setAttr('password', password) + else: + raise AssertionError + sasl_client.init() + return sasl_client + + transport = thrift_sasl.TSaslClientTransport(sasl_factory, + sasl_auth, + socket) + else: + # All HS2 config options: + # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration + # PAM currently left to end user via thrift_transport option. + raise NotImplementedError( + "Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM " + "authentication are supported, got {}".format(auth)) + + return transport + def __enter__(self): """Transport should already be opened by __init__""" return self diff --git a/pyhive/tests/test_hive.py b/pyhive/tests/test_hive.py index c70ed962..eb0e995a 100644 --- a/pyhive/tests/test_hive.py +++ b/pyhive/tests/test_hive.py @@ -200,6 +200,48 @@ def test_invalid_transport(self): lambda: hive.connect(_HOST, thrift_transport=transport) ) + def test_invalid_binary_auth(self): + invalid_binary_auth = 'invalid' + self.assertRaisesRegexp( + NotImplementedError, + 'Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM authentication are supported, ' + 'got {}'.format(invalid_binary_auth), + lambda: hive.connect(host=_HOST, auth=invalid_binary_auth) + ) + + def test_invalid_transport_protocol(self): + invalid_transport = 'invalid' + self.assertRaisesRegexp( + ValueError, + 'Invalid thrift_transport_protocol: {}'.format(invalid_transport), + lambda: hive.connect(host=_HOST, thrift_transport_protocol=invalid_transport) + ) + + def test_invalid_http_basic_auth(self): + self.assertRaisesRegexp( + ValueError, + 'BASIC authentication requires password.', + lambda: hive.connect(host=_HOST, thrift_transport_protocol='http', + auth='BASIC') + ) + + self.assertRaisesRegexp( + ValueError, + 'BASIC authentication requires password.', + lambda: hive.connect(host=_HOST, thrift_transport_protocol='http', + auth='BASIC', username='username') + ) + + def test_invalid_http_auth(self): + thrift_transport_protocol = 'http' + auth = 'LDAP' + self.assertRaisesRegexp( + NotImplementedError, + "Only NONE, NOSASL, BASIC and KERBEROS authentication are supported " + "when using HTTP transport, got {}".format(auth), + lambda: hive.connect(thrift_transport_protocol=thrift_transport_protocol, auth=auth) + ) + def test_custom_transport(self): socket = thrift.transport.TSocket.TSocket('localhost', 10000) sasl_auth = 'PLAIN' @@ -244,9 +286,127 @@ def test_custom_connection(self): subprocess.check_call(['sudo', 'cp', orig_none, des]) _restart_hs2() + def test_thrift_nosasl(self): + rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + orig_ldap = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site-nosasl.xml') + orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml') + des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') + try: + subprocess.check_call(['sudo', 'cp', orig_ldap, des]) + _restart_hs2() + with contextlib.closing(hive.connect( + host=_HOST, auth='NOSASL') + ) as connection: + with contextlib.closing(connection.cursor()) as cursor: + cursor.execute('SELECT * FROM one_row') + self.assertEqual(cursor.fetchall(), [(1,)]) + + finally: + subprocess.check_call(['sudo', 'cp', orig_none, des]) + _restart_hs2() + + def test_thrift_http_auth_none(self): + rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + orig_http = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', + 'hive-site-http-none.xml') + orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml') + des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') + try: + subprocess.check_call(['sudo', 'cp', orig_http, des]) + _restart_hs2(10001) + + with contextlib.closing(hive.connect( + host=_HOST, username='the-user', thrift_transport_protocol='http', + auth='NONE', http_path='/') + ) as connection: + with contextlib.closing(connection.cursor()) as cursor: + cursor.execute('SELECT * FROM one_row') + self.assertEqual(cursor.fetchall(), [(1,)]) + + with contextlib.closing(hive.connect( + host=_HOST, thrift_transport_protocol='http') + ) as connection: + with contextlib.closing(connection.cursor()) as cursor: + cursor.execute('SELECT * FROM one_row') + self.assertEqual(cursor.fetchall(), [(1,)]) + + finally: + subprocess.check_call(['sudo', 'cp', orig_none, des]) + _restart_hs2() + + def test_thrift_http_auth_none_with_path(self): + rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + orig_http = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', + 'hive-site-http-none-path.xml') + orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml') + des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') + try: + subprocess.check_call(['sudo', 'cp', orig_http, des]) + _restart_hs2(10001) + + with contextlib.closing(hive.connect( + host=_HOST, username='the-user', thrift_transport_protocol='http', + http_path="/servicepath", auth='NONE') + ) as connection: + with contextlib.closing(connection.cursor()) as cursor: + cursor.execute('SELECT * FROM one_row') + self.assertEqual(cursor.fetchall(), [(1,)]) + + finally: + subprocess.check_call(['sudo', 'cp', orig_none, des]) + _restart_hs2() + + def test_thrift_http_auth_nosasl(self): + rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + orig_http = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', + 'hive-site-http-nosasl.xml') + orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml') + des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') + try: + subprocess.check_call(['sudo', 'cp', orig_http, des]) + _restart_hs2(10001) + + with contextlib.closing(hive.connect( + host=_HOST, username='the-user', thrift_transport_protocol='http', + auth='NOSASL') + ) as connection: + with contextlib.closing(connection.cursor()) as cursor: + cursor.execute('SELECT * FROM one_row') + self.assertEqual(cursor.fetchall(), [(1,)]) + + finally: + subprocess.check_call(['sudo', 'cp', orig_none, des]) + _restart_hs2() + + @mock.patch('pyhive.hive.TCookieHttpClient') + def test_thrift_http_auth_kerberos(self, mock_tcookiehttpclient): + from pyhive.hive import Connection + import kerberos + + dummy_ctx = "ctx" + kerberos.authGSSClientInit = mock.MagicMock(return_value=(None, dummy_ctx)) + + mock_clean = mock.create_autospec(lambda *args, **kwargs: None) + kerberos.authGSSClientClean = mock_clean + + mock_step = mock.create_autospec(lambda *args, **kwargs: None) + kerberos.authGSSClientStep = mock_step + + kerberos.authGSSClientResponse = mock.create_autospec( + lambda ctx: self.assertEqual(ctx, dummy_ctx), return_value="auth_header_value") + + mock_tcookiehttpclient.setCustomHeaders = mock.create_autospec( + lambda __, headers: self.assertEqual(headers['Authorization'], + 'Negotiate auth_header_value')) + + Connection.create_http_transport(_HOST, None, None, None, None, "hive", "KERBEROS") + + mock_clean.assert_called_once_with(dummy_ctx, '') + mock_step.assert_called_once_with(dummy_ctx, '') + -def _restart_hs2(): +def _restart_hs2(port=10000): subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart']) with contextlib.closing(socket.socket()) as s: - while s.connect_ex(('localhost', 10000)) != 0: + while s.connect_ex((_HOST, port)) != 0: time.sleep(1) diff --git a/scripts/travis-conf/hive/hive-site-http-none-path.xml b/scripts/travis-conf/hive/hive-site-http-none-path.xml new file mode 100644 index 00000000..a5a56a7c --- /dev/null +++ b/scripts/travis-conf/hive/hive-site-http-none-path.xml @@ -0,0 +1,24 @@ + + + + + hive.metastore.uris + thrift://localhost:9083 + + + javax.jdo.option.ConnectionURL + jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true + + + fs.defaultFS + file:/// + + + hive.server2.transport.mode + http + + + hive.server2.thrift.http.path + /servicepath + + diff --git a/scripts/travis-conf/hive/hive-site-http-none.xml b/scripts/travis-conf/hive/hive-site-http-none.xml new file mode 100644 index 00000000..0479930f --- /dev/null +++ b/scripts/travis-conf/hive/hive-site-http-none.xml @@ -0,0 +1,24 @@ + + + + + hive.metastore.uris + thrift://localhost:9083 + + + javax.jdo.option.ConnectionURL + jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true + + + fs.defaultFS + file:/// + + + hive.server2.transport.mode + http + + + hive.server2.thrift.http.path + / + + diff --git a/scripts/travis-conf/hive/hive-site-http-nosasl.xml b/scripts/travis-conf/hive/hive-site-http-nosasl.xml new file mode 100644 index 00000000..8ce562c2 --- /dev/null +++ b/scripts/travis-conf/hive/hive-site-http-nosasl.xml @@ -0,0 +1,28 @@ + + + + + hive.metastore.uris + thrift://localhost:9083 + + + javax.jdo.option.ConnectionURL + jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true + + + fs.defaultFS + file:/// + + + hive.server2.transport.mode + http + + + hive.server2.thrift.http.path + / + + + hive.server2.authentication + NOSASL + + diff --git a/scripts/travis-conf/hive/hive-site-nosasl.xml b/scripts/travis-conf/hive/hive-site-nosasl.xml new file mode 100644 index 00000000..414242dc --- /dev/null +++ b/scripts/travis-conf/hive/hive-site-nosasl.xml @@ -0,0 +1,20 @@ + + + + + hive.metastore.uris + thrift://localhost:9083 + + + javax.jdo.option.ConnectionURL + jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true + + + fs.defaultFS + file:/// + + + hive.server2.authentication + NOSASL + +