Skip to content

Commit

Permalink
Merge pull request #818 from fishtown-analytics/feature/redshift-iam-…
Browse files Browse the repository at this point in the history
…auth

Feature/redshift iam auth
  • Loading branch information
drewbanin authored Jul 4, 2018
2 parents 5d9b8c5 + e7abe27 commit 145a82b
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 8 deletions.
8 changes: 7 additions & 1 deletion dbt/adapters/postgres/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def date_function(cls):
def get_status(cls, cursor):
return cursor.statusmessage

@classmethod
def get_credentials(cls, credentials):
return credentials

@classmethod
def open_connection(cls, connection):
if connection.get('state') == 'open':
Expand All @@ -61,8 +65,10 @@ def open_connection(cls, connection):

result = connection.copy()

base_credentials = connection.get('credentials', {})
credentials = cls.get_credentials(base_credentials.copy())

try:
credentials = connection.get('credentials', {})
handle = psycopg2.connect(
dbname=credentials.get('dbname'),
user=credentials.get('user'),
Expand Down
69 changes: 66 additions & 3 deletions dbt/adapters/redshift/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from dbt.adapters.postgres import PostgresAdapter
from dbt.logger import GLOBAL_LOGGER as logger # noqa

import dbt.exceptions
import boto3

drop_lock = multiprocessing.Lock()


class RedshiftAdapter(PostgresAdapter):

@classmethod
def type(cls):
return 'redshift'
Expand All @@ -17,6 +17,69 @@ def type(cls):
def date_function(cls):
return 'getdate()'

@classmethod
def fetch_cluster_credentials(cls, db_user, db_name, cluster_id,
duration_s):
"""Fetches temporary login credentials from AWS. The specified user
must already exist in the database, or else an error will occur"""
boto_client = boto3.client('redshift')

try:
return boto_client.get_cluster_credentials(
DbUser=db_user,
DbName=db_name,
ClusterIdentifier=cluster_id,
DurationSeconds=duration_s,
AutoCreate=False)

except boto_client.exceptions.ClientError as e:
raise dbt.exceptions.FailedToConnectException(
"Unable to get temporary Redshift cluster credentials: "
"{}".format(e))

@classmethod
def get_tmp_iam_cluster_credentials(cls, credentials):
cluster_id = credentials.get('cluster_id')

# default via:
# boto3.readthedocs.io/en/latest/reference/services/redshift.html
iam_duration_s = credentials.get('iam_duration_seconds', 900)

if not cluster_id:
raise dbt.exceptions.FailedToConnectException(
"'cluster_id' must be provided in profile if IAM "
"authentication method selected")

cluster_creds = cls.fetch_cluster_credentials(
credentials.get('user'),
credentials.get('dbname'),
credentials.get('cluster_id'),
iam_duration_s,
)

# replace username and password with temporary redshift credentials
return dbt.utils.merge(credentials, {
'user': cluster_creds.get('DbUser'),
'pass': cluster_creds.get('DbPassword')
})

@classmethod
def get_credentials(cls, credentials):
method = credentials.get('method')

# Support missing 'method' for backwards compatibility
if method == 'database' or method is None:
logger.debug("Connecting to Redshift using 'database' credentials")
return credentials

elif method == 'iam':
logger.debug("Connecting to Redshift using 'IAM' credentials")
return cls.get_tmp_iam_cluster_credentials(credentials)

else:
raise dbt.exceptions.FailedToConnectException(
"Invalid 'method' in profile: '{}'".format(method))

@classmethod
def _get_columns_in_table_sql(cls, schema_name, table_name, database):
# Redshift doesn't support cross-database queries,
Expand All @@ -27,7 +90,7 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database):
table_schema_filter = '1=1'
else:
table_schema_filter = "table_schema = '{schema_name}'".format(
schema_name=schema_name)
schema_name=schema_name)

sql = """
with bound_views as (
Expand Down
67 changes: 63 additions & 4 deletions dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,61 @@
'required': ['dbname', 'host', 'user', 'pass', 'port', 'schema'],
}

REDSHIFT_CREDENTIALS_CONTRACT = {
'type': 'object',
'additionalProperties': False,
'properties': {
'method': {
'enum': ['database', 'iam'],
'description': (
'database: use user/pass creds; iam: use temporary creds'
),
},
'dbname': {
'type': 'string',
},
'host': {
'type': 'string',
},
'user': {
'type': 'string',
},
'pass': {
'type': 'string',
},
'port': {
'oneOf': [
{
'type': 'integer',
'minimum': 0,
'maximum': 65535,
},
{
'type': 'string'
},
],
},
'schema': {
'type': 'string',
},
'cluster_id': {
'type': 'string',
'description': (
'If using IAM auth, the name of the cluster'
)
},
'iam_duration_seconds': {
'type': 'integer',
'minimum': 900,
'maximum': 3600,
'description': (
'If using IAM auth, the ttl for the temporary credentials'
)
},
'required': ['dbname', 'host', 'user', 'port', 'schema']
}
}

SNOWFLAKE_CREDENTIALS_CONTRACT = {
'type': 'object',
'additionalProperties': False,
Expand Down Expand Up @@ -113,11 +168,11 @@
},
'credentials': {
'description': (
'The credentials object here should match the connection '
'type. Redshift uses the Postgres connection model.'
'The credentials object here should match the connection type.'
),
'oneOf': [
'anyOf': [
POSTGRES_CREDENTIALS_CONTRACT,
REDSHIFT_CREDENTIALS_CONTRACT,
SNOWFLAKE_CREDENTIALS_CONTRACT,
BIGQUERY_CREDENTIALS_CONTRACT,
],
Expand All @@ -133,6 +188,10 @@ class PostgresCredentials(APIObject):
SCHEMA = POSTGRES_CREDENTIALS_CONTRACT


class RedshiftCredentials(APIObject):
SCHEMA = REDSHIFT_CREDENTIALS_CONTRACT


class SnowflakeCredentials(APIObject):
SCHEMA = SNOWFLAKE_CREDENTIALS_CONTRACT

Expand All @@ -143,7 +202,7 @@ class BigQueryCredentials(APIObject):

CREDENTIALS_MAPPING = {
'postgres': PostgresCredentials,
'redshift': PostgresCredentials,
'redshift': RedshiftCredentials,
'snowflake': SnowflakeCredentials,
'bigquery': BigQueryCredentials,
}
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ google-cloud-bigquery==0.29.0
requests>=2.18.0
agate>=1.6,<2
jsonschema==2.6.0
boto3>=1.6.23
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,6 @@ def read(fname):
'google-cloud-bigquery==0.29.0',
'agate>=1.6,<2',
'jsonschema==2.6.0',
'boto3>=1.6.23'
]
)
98 changes: 98 additions & 0 deletions test/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import unittest
import mock

import dbt.flags as flags
import dbt.utils

from dbt.adapters.redshift import RedshiftAdapter
from dbt.exceptions import ValidationException, FailedToConnectException
from dbt.logger import GLOBAL_LOGGER as logger # noqa

@classmethod
def fetch_cluster_credentials(*args, **kwargs):
return {
'DbUser': 'root',
'DbPassword': 'tmp_password'
}

class TestRedshiftAdapter(unittest.TestCase):

def setUp(self):
flags.STRICT_MODE = True

def test_implicit_database_conn(self):
implicit_database_profile = {
'dbname': 'redshift',
'user': 'root',
'host': 'database',
'pass': 'password',
'port': 5439,
'schema': 'public'
}

creds = RedshiftAdapter.get_credentials(implicit_database_profile)
self.assertEquals(creds, implicit_database_profile)

def test_explicit_database_conn(self):
explicit_database_profile = {
'method': 'database',
'dbname': 'redshift',
'user': 'root',
'host': 'database',
'pass': 'password',
'port': 5439,
'schema': 'public'
}

creds = RedshiftAdapter.get_credentials(explicit_database_profile)
self.assertEquals(creds, explicit_database_profile)

def test_explicit_iam_conn(self):
explicit_iam_profile = {
'method': 'iam',
'cluster_id': 'my_redshift',
'iam_duration_s': 1200,
'dbname': 'redshift',
'user': 'root',
'host': 'database',
'port': 5439,
'schema': 'public',
}

with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials):
creds = RedshiftAdapter.get_credentials(explicit_iam_profile)

expected_creds = dbt.utils.merge(explicit_iam_profile, {'pass': 'tmp_password'})
self.assertEquals(creds, expected_creds)

def test_invalid_auth_method(self):
invalid_profile = {
'method': 'badmethod',
'dbname': 'redshift',
'user': 'root',
'host': 'database',
'pass': 'password',
'port': 5439,
'schema': 'public'
}

with self.assertRaises(dbt.exceptions.FailedToConnectException) as context:
with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials):
RedshiftAdapter.get_credentials(invalid_profile)

self.assertTrue('badmethod' in context.exception.msg)

def test_invalid_iam_no_cluster_id(self):
invalid_profile = {
'method': 'iam',
'dbname': 'redshift',
'user': 'root',
'host': 'database',
'port': 5439,
'schema': 'public'
}
with self.assertRaises(dbt.exceptions.FailedToConnectException) as context:
with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials):
RedshiftAdapter.get_credentials(invalid_profile)

self.assertTrue("'cluster_id' must be provided" in context.exception.msg)

0 comments on commit 145a82b

Please sign in to comment.