override seed types (#708)
* override seed types

* s/_columns/column_types/g

* pep8

* fix unit tests, add integration test

* add bq, snowflake tests for seed type overrides
drewbanin authored Apr 23, 2018
1 parent e20796e commit 3567e20
Showing 16 changed files with 158 additions and 18 deletions.
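This change lets a seed file declare explicit column types instead of relying purely on agate's type inference. A minimal sketch of the new config surface, written in the same dict form that the integration test below passes as its project_config (the project and seed names here are hypothetical):

project_config = {
    "seeds": {
        "my_project": {                      # hypothetical project name
            "my_seed": {                     # hypothetical seed (CSV) name
                "column_types": {
                    "id": "character varying(255)",
                    "birthday": "date",
                },
            },
        },
    },
}

Columns left out of column_types keep their inferred types; each adapter falls back to convert_agate_type for them, as the diffs below show.
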
13 changes: 8 additions & 5 deletions dbt/adapters/bigquery.py
@@ -493,7 +493,8 @@ def convert_datetime_type(cls, agate_table, col_idx):
         return "datetime"
 
     @classmethod
-    def create_csv_table(cls, profile, schema, table_name, agate_table):
+    def create_csv_table(cls, profile, schema, table_name, agate_table,
+                         column_override):
         pass
 
     @classmethod
@@ -502,17 +503,19 @@ def reset_csv_table(cls, profile, schema, table_name, agate_table,
         cls.drop(profile, schema, table_name, "table")
 
     @classmethod
-    def _agate_to_schema(cls, agate_table):
+    def _agate_to_schema(cls, agate_table, column_override):
         bq_schema = []
         for idx, col_name in enumerate(agate_table.column_names):
-            type_ = cls.convert_agate_type(agate_table, idx)
+            inferred_type = cls.convert_agate_type(agate_table, idx)
+            type_ = column_override.get(col_name, inferred_type)
             bq_schema.append(
                 google.cloud.bigquery.SchemaField(col_name, type_))
         return bq_schema
 
     @classmethod
-    def load_csv_rows(cls, profile, schema, table_name, agate_table):
-        bq_schema = cls._agate_to_schema(agate_table)
+    def load_csv_rows(cls, profile, schema, table_name, agate_table,
+                      column_override):
+        bq_schema = cls._agate_to_schema(agate_table, column_override)
         dataset = cls.get_dataset(profile, schema, None)
         table = dataset.table(table_name)
         conn = cls.get_connection(profile, None)
16 changes: 10 additions & 6 deletions dbt/adapters/default.py
@@ -99,7 +99,8 @@ def cancel_connection(cls, project, connection):
             '`cancel_connection` is not implemented for this adapter!')
 
     @classmethod
-    def create_csv_table(cls, profile, schema, table_name, agate_table):
+    def create_csv_table(cls, profile, schema, table_name, agate_table,
+                         column_override):
         raise dbt.exceptions.NotImplementedException(
             '`create_csv_table` is not implemented for this adapter!')
 
@@ -110,7 +111,8 @@ def reset_csv_table(cls, profile, schema, table_name, agate_table,
             '`reset_csv_table` is not implemented for this adapter!')
 
     @classmethod
-    def load_csv_rows(cls, profile, schema, table_name, agate_table):
+    def load_csv_rows(cls, profile, schema, table_name, agate_table,
+                      column_override):
         raise dbt.exceptions.NotImplementedException(
             '`load_csv_rows` is not implemented for this adapter!')
 
@@ -647,18 +649,20 @@ def quote_schema_and_table(cls, profile, schema, table, model_name=None):
 
     @classmethod
     def handle_csv_table(cls, profile, schema, table_name, agate_table,
-                         full_refresh=False):
+                         column_override, full_refresh=False):
         existing = cls.query_for_existing(profile, schema)
         existing_type = existing.get(table_name)
         if existing_type and existing_type != "table":
             raise dbt.exceptions.RuntimeException(
                 "Cannot seed to '{}', it is a view".format(table_name))
         if existing_type:
             cls.reset_csv_table(profile, schema, table_name, agate_table,
-                                full_refresh=full_refresh)
+                                column_override, full_refresh=full_refresh)
         else:
-            cls.create_csv_table(profile, schema, table_name, agate_table)
-            cls.load_csv_rows(profile, schema, table_name, agate_table)
+            cls.create_csv_table(profile, schema, table_name, agate_table,
+                                 column_override)
+            cls.load_csv_rows(profile, schema, table_name, agate_table,
+                              column_override)
         cls.commit_if_has_connection(profile, None)
 
     @classmethod
16 changes: 11 additions & 5 deletions dbt/adapters/postgres.py
@@ -200,26 +200,32 @@ def convert_time_type(cls, agate_table, col_idx):
         return "time"
 
     @classmethod
-    def create_csv_table(cls, profile, schema, table_name, agate_table):
+    def create_csv_table(cls, profile, schema, table_name, agate_table,
+                         column_override):
         col_sqls = []
         for idx, col_name in enumerate(agate_table.column_names):
-            type_ = cls.convert_agate_type(agate_table, idx)
+            inferred_type = cls.convert_agate_type(agate_table, idx)
+            type_ = column_override.get(col_name, inferred_type)
             col_sqls.append('{} {}'.format(col_name, type_))
         sql = 'create table "{}"."{}" ({})'.format(schema, table_name,
                                                    ", ".join(col_sqls))
         return cls.add_query(profile, sql)
 
     @classmethod
     def reset_csv_table(cls, profile, schema, table_name, agate_table,
-                        full_refresh=False):
+                        column_override, full_refresh=False):
         if full_refresh:
             cls.drop_table(profile, schema, table_name, None)
-            cls.create_csv_table(profile, schema, table_name, agate_table)
+            cls.create_csv_table(profile, schema, table_name, agate_table,
+                                 column_override)
         else:
             cls.truncate(profile, schema, table_name)
 
     @classmethod
-    def load_csv_rows(cls, profile, schema, table_name, agate_table):
+    def load_csv_rows(cls, profile, schema, table_name, agate_table,
+                      column_override):
         bindings = []
         placeholders = []
         cols_sql = ", ".join(c for c in agate_table.column_names)
 
         for chunk in chunks(agate_table.rows, 10000):
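
To make the override semantics concrete, here is a small standalone sketch (plain Python, not dbt code; the inferred types, schema, and table names are made up) of how the override map wins over inferred types while create_csv_table assembles the seed DDL:

# Standalone sketch of the override merge in create_csv_table above.
# The inferred types and the schema/table identifiers are illustrative only.
inferred = {"id": "integer", "birthday": "date", "name": "text"}
column_override = {"id": "character varying(255)"}

col_sqls = []
for col_name, inferred_type in inferred.items():
    type_ = column_override.get(col_name, inferred_type)
    col_sqls.append('{} {}'.format(col_name, type_))

sql = 'create table "{}"."{}" ({})'.format(
    "analytics", "seed_enabled", ", ".join(col_sqls))
print(sql)
# create table "analytics"."seed_enabled" (id character varying(255), birthday date, name text)

The BigQuery adapter applies the same merge when it builds its SchemaField list, and the default adapter simply threads column_override through to whichever implementation handles the seed.
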
1 change: 1 addition & 0 deletions dbt/contracts/graph/parsed.py
@@ -25,6 +25,7 @@
     Required('post-hook'): [hook_contract],
     Required('pre-hook'): [hook_contract],
     Required('vars'): dict,
+    Required('column_types'): dict,
 }, extra=ALLOW_EXTRA)
 
 parsed_node_contract = unparsed_node_contract.extend({
2 changes: 1 addition & 1 deletion dbt/model.py
@@ -12,7 +12,7 @@ class SourceConfig(object):
     ConfigKeys = DBTConfigKeys
 
     AppendListFields = ['pre-hook', 'post-hook']
-    ExtendDictFields = ['vars']
+    ExtendDictFields = ['vars', 'column_types']
     ClobberFields = [
         'schema',
         'enabled',
3 changes: 3 additions & 0 deletions dbt/node_runners.py
@@ -509,7 +509,10 @@ def execute(self, compiled_node, existing_, flat_graph):
         schema = compiled_node["schema"]
         table_name = compiled_node["name"]
         table = compiled_node["agate_table"]
+
+        column_override = compiled_node['config'].get('column_types', {})
         self.adapter.handle_csv_table(self.profile, schema, table_name, table,
+                                      column_override,
                                       full_refresh=dbt.flags.FULL_REFRESH)
 
         if dbt.flags.FULL_REFRESH:
1 change: 1 addition & 0 deletions dbt/utils.py
@@ -26,6 +26,7 @@
     'pre-hook',
     'post-hook',
     'vars',
+    'column_types',
     'bind',
 ]
 
15 changes: 15 additions & 0 deletions test/integration/005_simple_seed_test/macros/schema_test.sql
@@ -0,0 +1,15 @@

{% macro test_column_type(model, field, type) %}

{% set cols = adapter.get_columns_in_table(model.schema, model.name) %}

{% set col_types = {} %}
{% for col in cols %}
{% set _ = col_types.update({col.name: col.data_type}) %}
{% endfor %}

{% set val = 0 if col_types[field] == type else 1 %}

select {{ val }} as pass_fail

{% endmacro %}
6 changes: 6 additions & 0 deletions test/integration/005_simple_seed_test/models-bq/schema.yml
@@ -0,0 +1,6 @@

seed_enabled:
  constraints:
    column_type:
      - {field: id, type: 'FLOAT' }
      - {field: birthday, type: 'STRING' }
6 changes: 6 additions & 0 deletions test/integration/005_simple_seed_test/models-pg/schema.yml
@@ -0,0 +1,6 @@

seed_enabled:
  constraints:
    column_type:
      - {field: id, type: 'character varying(255)' }
      - {field: birthday, type: 'date' }
6 changes: 6 additions & 0 deletions test/integration/005_simple_seed_test/models-snowflake/schema.yml
@@ -0,0 +1,6 @@

seed_enabled:
  constraints:
    column_type:
      - {field: ID, type: 'FLOAT' }
      - {field: BIRTHDAY, type: 'character varying(16777216)' }
Empty file.
87 changes: 87 additions & 0 deletions test/integration/005_simple_seed_test/test_seed_type_override.py
@@ -0,0 +1,87 @@
from nose.plugins.attrib import attr
from test.integration.base import DBTIntegrationTest

class TestSimpleSeedColumnOverride(DBTIntegrationTest):

    @property
    def schema(self):
        return "simple_seed_005"

    @property
    def project_config(self):
        return {
            "data-paths": ['test/integration/005_simple_seed_test/data-config'],
            "macro-paths": ['test/integration/005_simple_seed_test/macros'],
            "seeds": {
                "test": {
                    "enabled": False,
                    "seed_enabled": {
                        "enabled": True,
                        "column_types": self.seed_types()
                    },
                }
            }
        }

class TestSimpleSeedColumnOverridePostgres(TestSimpleSeedColumnOverride):
    @property
    def models(self):
        return "test/integration/005_simple_seed_test/models-pg"

    @property
    def profile_config(self):
        return self.postgres_profile()

    def seed_types(self):
        return {
            "id": "text",
            "birthday": "date",
        }

    @attr(type='postgres')
    def test_simple_seed_with_column_override_pg(self):
        self.run_dbt(["seed"])
        self.run_dbt(["test"])


class TestSimpleSeedColumnOverrideSnowflake(TestSimpleSeedColumnOverride):
    @property
    def models(self):
        return "test/integration/005_simple_seed_test/models-snowflake"

    def seed_types(self):
        return {
            "id": "FLOAT",
            "birthday": "TEXT",
        }

    @property
    def profile_config(self):
        return self.snowflake_profile()

    @attr(type='snowflake')
    def test_simple_seed_with_column_override_snowflake(self):
        self.run_dbt(["seed"])
        self.run_dbt(["test"])

class TestSimpleSeedColumnOverrideBQ(TestSimpleSeedColumnOverride):
    @property
    def models(self):
        return "test/integration/005_simple_seed_test/models-bq"

    def seed_types(self):
        return {
            "id": "FLOAT64",
            "birthday": "STRING",
        }

    @property
    def profile_config(self):
        return self.bigquery_profile()

    @attr(type='bigquery')
    def test_simple_seed_with_column_override_bq(self):
        self.run_dbt(["seed"])
        self.run_dbt(["test"])
1 change: 0 additions & 1 deletion test/integration/005_simple_seed_test/test_simple_seed.py
@@ -120,4 +120,3 @@ def test_simple_seed_with_disabled(self):
         self.run_dbt(["seed"])
         self.assertTableDoesExist('seed_enabled')
         self.assertTableDoesNotExist('seed_disabled')
-
1 change: 1 addition & 0 deletions test/unit/test_compiler.py
@@ -38,6 +38,7 @@ def setUp(self):
             'post-hook': [],
             'pre-hook': [],
             'vars': {},
+            'column_types': {},
         }
 
     def test__prepend_ctes__already_has_cte(self):
2 changes: 2 additions & 0 deletions test/unit/test_parser.py
@@ -57,6 +57,7 @@ def setUp(self):
             'post-hook': [],
             'pre-hook': [],
             'vars': {},
+            'column_types': {},
         }
 
         self.disabled_config = {
@@ -65,6 +66,7 @@
             'post-hook': [],
             'pre-hook': [],
             'vars': {},
+            'column_types': {},
         }
 
     def test__single_model(self):
