diff --git a/dbt/adapters/postgres/impl.py b/dbt/adapters/postgres/impl.py index 50c05b0f199..f9a69ad386c 100644 --- a/dbt/adapters/postgres/impl.py +++ b/dbt/adapters/postgres/impl.py @@ -53,6 +53,10 @@ def date_function(cls): def get_status(cls, cursor): return cursor.statusmessage + @classmethod + def get_credentials(cls, credentials): + return credentials + @classmethod def open_connection(cls, connection): if connection.get('state') == 'open': @@ -61,8 +65,10 @@ def open_connection(cls, connection): result = connection.copy() + base_credentials = connection.get('credentials', {}) + credentials = cls.get_credentials(base_credentials.copy()) + try: - credentials = connection.get('credentials', {}) handle = psycopg2.connect( dbname=credentials.get('dbname'), user=credentials.get('user'), diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 311e09d3923..319edf2bf45 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -2,13 +2,13 @@ from dbt.adapters.postgres import PostgresAdapter from dbt.logger import GLOBAL_LOGGER as logger # noqa - +import dbt.exceptions +import boto3 drop_lock = multiprocessing.Lock() class RedshiftAdapter(PostgresAdapter): - @classmethod def type(cls): return 'redshift' @@ -17,6 +17,69 @@ def type(cls): def date_function(cls): return 'getdate()' + @classmethod + def fetch_cluster_credentials(cls, db_user, db_name, cluster_id, + duration_s): + """Fetches temporary login credentials from AWS. 
The specified user + must already exist in the database, or else an error will occur""" + boto_client = boto3.client('redshift') + + try: + return boto_client.get_cluster_credentials( + DbUser=db_user, + DbName=db_name, + ClusterIdentifier=cluster_id, + DurationSeconds=duration_s, + AutoCreate=False) + + except boto_client.exceptions.ClientError as e: + raise dbt.exceptions.FailedToConnectException( + "Unable to get temporary Redshift cluster credentials: " + "{}".format(e)) + + @classmethod + def get_tmp_iam_cluster_credentials(cls, credentials): + cluster_id = credentials.get('cluster_id') + + # default via: + # boto3.readthedocs.io/en/latest/reference/services/redshift.html + iam_duration_s = credentials.get('iam_duration_seconds', 900) + + if not cluster_id: + raise dbt.exceptions.FailedToConnectException( + "'cluster_id' must be provided in profile if IAM " + "authentication method selected") + + cluster_creds = cls.fetch_cluster_credentials( + credentials.get('user'), + credentials.get('dbname'), + credentials.get('cluster_id'), + iam_duration_s, + ) + + # replace username and password with temporary redshift credentials + return dbt.utils.merge(credentials, { + 'user': cluster_creds.get('DbUser'), + 'pass': cluster_creds.get('DbPassword') + }) + + @classmethod + def get_credentials(cls, credentials): + method = credentials.get('method') + + # Support missing 'method' for backwards compatibility + if method == 'database' or method is None: + logger.debug("Connecting to Redshift using 'database' credentials") + return credentials + + elif method == 'iam': + logger.debug("Connecting to Redshift using 'IAM' credentials") + return cls.get_tmp_iam_cluster_credentials(credentials) + + else: + raise dbt.exceptions.FailedToConnectException( + "Invalid 'method' in profile: '{}'".format(method)) + @classmethod def _get_columns_in_table_sql(cls, schema_name, table_name, database): # Redshift doesn't support cross-database queries, @@ -27,7 +90,7 @@ def 
_get_columns_in_table_sql(cls, schema_name, table_name, database): table_schema_filter = '1=1' else: table_schema_filter = "table_schema = '{schema_name}'".format( - schema_name=schema_name) + schema_name=schema_name) sql = """ with bound_views as ( diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index 1ac374ea095..2183c5d891b 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -36,6 +36,61 @@ 'required': ['dbname', 'host', 'user', 'pass', 'port', 'schema'], } +REDSHIFT_CREDENTIALS_CONTRACT = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'method': { + 'enum': ['database', 'iam'], + 'description': ( + 'database: use user/pass creds; iam: use temporary creds' + ), + }, + 'dbname': { + 'type': 'string', + }, + 'host': { + 'type': 'string', + }, + 'user': { + 'type': 'string', + }, + 'pass': { + 'type': 'string', + }, + 'port': { + 'oneOf': [ + { + 'type': 'integer', + 'minimum': 0, + 'maximum': 65535, + }, + { + 'type': 'string' + }, + ], + }, + 'schema': { + 'type': 'string', + }, + 'cluster_id': { + 'type': 'string', + 'description': ( + 'If using IAM auth, the name of the cluster' + ) + }, + 'iam_duration_seconds': { + 'type': 'integer', + 'minimum': 900, + 'maximum': 3600, + 'description': ( + 'If using IAM auth, the ttl for the temporary credentials' + ) + }, + }, + 'required': ['dbname', 'host', 'user', 'port', 'schema'] +} + SNOWFLAKE_CREDENTIALS_CONTRACT = { 'type': 'object', 'additionalProperties': False, @@ -113,11 +168,11 @@ }, 'credentials': { 'description': ( - 'The credentials object here should match the connection ' - 'type. Redshift uses the Postgres connection model.'
), - 'oneOf': [ + 'anyOf': [ POSTGRES_CREDENTIALS_CONTRACT, + REDSHIFT_CREDENTIALS_CONTRACT, SNOWFLAKE_CREDENTIALS_CONTRACT, BIGQUERY_CREDENTIALS_CONTRACT, ], @@ -133,6 +188,10 @@ class PostgresCredentials(APIObject): SCHEMA = POSTGRES_CREDENTIALS_CONTRACT +class RedshiftCredentials(APIObject): + SCHEMA = REDSHIFT_CREDENTIALS_CONTRACT + + class SnowflakeCredentials(APIObject): SCHEMA = SNOWFLAKE_CREDENTIALS_CONTRACT @@ -143,7 +202,7 @@ class BigQueryCredentials(APIObject): CREDENTIALS_MAPPING = { 'postgres': PostgresCredentials, - 'redshift': PostgresCredentials, + 'redshift': RedshiftCredentials, 'snowflake': SnowflakeCredentials, 'bigquery': BigQueryCredentials, } diff --git a/requirements.txt b/requirements.txt index 5c2e7de65d2..5b59f615305 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ google-cloud-bigquery==0.29.0 requests>=2.18.0 agate>=1.6,<2 jsonschema==2.6.0 +boto3>=1.6.23 diff --git a/setup.py b/setup.py index 04c4c443fe1..b6e69f6749e 100644 --- a/setup.py +++ b/setup.py @@ -52,5 +52,6 @@ def read(fname): 'google-cloud-bigquery==0.29.0', 'agate>=1.6,<2', 'jsonschema==2.6.0', + 'boto3>=1.6.23' ] ) diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py new file mode 100644 index 00000000000..505ae4b536d --- /dev/null +++ b/test/unit/test_redshift_adapter.py @@ -0,0 +1,98 @@ +import unittest +import mock + +import dbt.flags as flags +import dbt.utils + +from dbt.adapters.redshift import RedshiftAdapter +from dbt.exceptions import ValidationException, FailedToConnectException +from dbt.logger import GLOBAL_LOGGER as logger # noqa + +@classmethod +def fetch_cluster_credentials(*args, **kwargs): + return { + 'DbUser': 'root', + 'DbPassword': 'tmp_password' + } + +class TestRedshiftAdapter(unittest.TestCase): + + def setUp(self): + flags.STRICT_MODE = True + + def test_implicit_database_conn(self): + implicit_database_profile = { + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'pass': 
'password', + 'port': 5439, + 'schema': 'public' + } + + creds = RedshiftAdapter.get_credentials(implicit_database_profile) + self.assertEqual(creds, implicit_database_profile) + + def test_explicit_database_conn(self): + explicit_database_profile = { + 'method': 'database', + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5439, + 'schema': 'public' + } + + creds = RedshiftAdapter.get_credentials(explicit_database_profile) + self.assertEqual(creds, explicit_database_profile) + + def test_explicit_iam_conn(self): + explicit_iam_profile = { + 'method': 'iam', + 'cluster_id': 'my_redshift', + 'iam_duration_seconds': 1200, + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'port': 5439, + 'schema': 'public', + } + + with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): + creds = RedshiftAdapter.get_credentials(explicit_iam_profile) + + expected_creds = dbt.utils.merge(explicit_iam_profile, {'pass': 'tmp_password'}) + self.assertEqual(creds, expected_creds) + + def test_invalid_auth_method(self): + invalid_profile = { + 'method': 'badmethod', + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5439, + 'schema': 'public' + } + + with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: + with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): + RedshiftAdapter.get_credentials(invalid_profile) + + self.assertTrue('badmethod' in context.exception.msg) + + def test_invalid_iam_no_cluster_id(self): + invalid_profile = { + 'method': 'iam', + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'port': 5439, + 'schema': 'public' + } + with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: + with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): + 
RedshiftAdapter.get_credentials(invalid_profile) + + self.assertTrue("'cluster_id' must be provided" in context.exception.msg)