Skip to content

Commit

Permalink
Merge pull request #1614 from fishtown-analytics/fix/snapshot-check-c…
Browse files Browse the repository at this point in the history
…ols-cycle

possible fix for re-used check cols on BQ
  • Loading branch information
drewbanin authored Jul 24, 2019
2 parents 8d4f2bd + b6e7351 commit b12484b
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@

),

snapshotted_data as (

select *,
{{ strategy.unique_key }} as dbt_unique_key

from {{ target_relation }}

),

source_data as (

select *,
Expand All @@ -43,15 +52,6 @@
from snapshot_query
),

snapshotted_data as (

select *,
{{ strategy.unique_key }} as dbt_unique_key

from {{ target_relation }}

),

insertions as (

select
Expand Down Expand Up @@ -84,6 +84,15 @@

),

snapshotted_data as (

select *,
{{ strategy.unique_key }} as dbt_unique_key

from {{ target_relation }}

),

source_data as (

select
Expand All @@ -96,15 +105,6 @@
from snapshot_query
),

snapshotted_data as (

select *,
{{ strategy.unique_key }} as dbt_unique_key

from {{ target_relation }}

),

updates as (

select
Expand Down Expand Up @@ -202,7 +202,7 @@
{%- endif -%}

{% set strategy_macro = strategy_dispatch(strategy_name) %}
{% set strategy = strategy_macro(model, "snapshotted_data", "source_data", config) %}
{% set strategy = strategy_macro(model, "snapshotted_data", "source_data", config, target_relation_exists) %}

{% if not target_relation_exists %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,17 @@
{#
Create SCD Hash SQL fields cross-db
#}
{% macro snapshot_hash_arguments(args) %}
{% macro snapshot_hash_arguments(args) -%}
{{ adapter_macro('snapshot_hash_arguments', args) }}
{% endmacro %}
{%- endmacro %}


{% macro default__snapshot_hash_arguments(args) %}
md5({% for arg in args %}
coalesce(cast({{ arg }} as varchar ), '') {% if not loop.last %} || '|' || {% endif %}
{% endfor %})
{% endmacro %}
{% macro default__snapshot_hash_arguments(args) -%}
md5({%- for arg in args -%}
coalesce(cast({{ arg }} as varchar ), '')
{% if not loop.last %} || '|' || {% endif %}
{%- endfor -%})
{%- endmacro %}


{#
Expand All @@ -62,7 +63,7 @@
{#
Core strategy definitions
#}
{% macro snapshot_timestamp_strategy(node, snapshotted_rel, current_rel, config) %}
{% macro snapshot_timestamp_strategy(node, snapshotted_rel, current_rel, config, target_exists) %}
{% set primary_key = config['unique_key'] %}
{% set updated_at = config['updated_at'] %}

Expand All @@ -81,7 +82,7 @@
{% endmacro %}


{% macro snapshot_check_strategy(node, snapshotted_rel, current_rel, config) %}
{% macro snapshot_check_strategy(node, snapshotted_rel, current_rel, config, target_exists) %}
{% set check_cols_config = config['check_cols'] %}
{% set primary_key = config['unique_key'] %}
{% set updated_at = snapshot_get_time() %}
Expand All @@ -106,7 +107,18 @@
)
{%- endset %}

{% set scd_id_cols = [primary_key] + (check_cols | list) %}
{% if target_exists %}
{% set row_version -%}
(
select count(*) from {{ snapshotted_rel }}
where {{ snapshotted_rel }}.dbt_unique_key = {{ primary_key }}
)
{%- endset %}
{% set scd_id_cols = [primary_key, row_version] + (check_cols | list) %}
{% else %}
{% set scd_id_cols = [primary_key] + (check_cols | list) %}
{% endif %}

{% set scd_id_expr = snapshot_hash_arguments(scd_id_cols) %}

{% do return({
Expand Down
22 changes: 19 additions & 3 deletions plugins/bigquery/dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,31 @@ def raw_execute(self, sql, fetch=False):

def execute(self, sql, auto_begin=False, fetch=None):
# auto_begin is ignored on bigquery, and only included for consistency
_, iterator = self.raw_execute(sql, fetch=fetch)
query_job, iterator = self.raw_execute(sql, fetch=fetch)

if fetch:
res = self.get_table_from_response(iterator)
else:
res = dbt.clients.agate_helper.empty_table()

# If we get here, the query succeeded
status = 'OK'
if query_job.statement_type == 'CREATE_VIEW':
status = 'CREATE VIEW'

elif query_job.statement_type == 'CREATE_TABLE_AS_SELECT':
conn = self.get_thread_connection()
client = conn.handle
table = client.get_table(query_job.destination)
status = 'CREATE TABLE ({})'.format(table.num_rows)

elif query_job.statement_type in ['INSERT', 'DELETE', 'MERGE']:
status = '{} ({})'.format(
query_job.statement_type,
query_job.num_dml_affected_rows
)

else:
status = 'OK'

return status, res

def create_bigquery_table(self, database, schema, table_name, callback,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
{% macro bigquery__snapshot_hash_arguments(args) %}
to_hex(md5(concat({% for arg in args %}coalesce(cast({{ arg }} as string), ''){% if not loop.last %}, '|',{% endif %}{% endfor %})))
{% endmacro %}
{% macro bigquery__snapshot_hash_arguments(args) -%}
to_hex(md5(concat({%- for arg in args -%}
coalesce(cast({{ arg }} as string), ''){% if not loop.last %}, '|',{% endif -%}
{%- endfor -%}
)))
{%- endmacro %}

{% macro bigquery__create_columns(relation, columns) %}
{{ adapter.alter_table_add_columns(relation, columns) }}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@


with query as (

-- check that the current value for id=1 is red
select case when (
select count(*)
from {{ ref('check_cols_cycle') }}
where id = 1 and color = 'red' and dbt_valid_to is null
) = 1 then 0 else 1 end as failures

union all

-- check that the previous 'red' value for id=1 is invalidated
select case when (
select count(*)
from {{ ref('check_cols_cycle') }}
where id = 1 and color = 'red' and dbt_valid_to is not null
) = 1 then 0 else 1 end as failures

union all

-- check that there's only one current record for id=2
select case when (
select count(*)
from {{ ref('check_cols_cycle') }}
where id = 2 and color = 'pink' and dbt_valid_to is null
) = 1 then 0 else 1 end as failures

union all

-- check that the previous value for id=2 is represented
select case when (
select count(*)
from {{ ref('check_cols_cycle') }}
where id = 2 and color = 'green' and dbt_valid_to is not null
) = 1 then 0 else 1 end as failures

union all

-- check that there are 5 records total in the table
select case when (
select count(*)
from {{ ref('check_cols_cycle') }}
) = 5 then 0 else 1 end as failures

)

select *
from query
where failures = 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

{% snapshot check_cols_cycle %}

{{
config(
target_database=database,
target_schema=schema,
unique_key='id',
strategy='check',
check_cols=['color']
)
}}

{% if var('version') == 1 %}

select 1 as id, 'red' as color union all
select 2 as id, 'green' as color

{% elif var('version') == 2 %}

select 1 as id, 'blue' as color union all
select 2 as id, 'green' as color

{% elif var('version') == 3 %}

select 1 as id, 'red' as color union all
select 2 as id, 'pink' as color

{% else %}
{% do exceptions.raise_compiler_error("Got bad version: " ~ var('version')) %}
{% endif %}

{% endsnapshot %}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from test.integration.base import DBTIntegrationTest, use_profile
import dbt.exceptions


class TestSimpleSnapshotFiles(DBTIntegrationTest):
NUM_SNAPSHOT_MODELS = 1

@property
def schema(self):
return "simple_snapshot_004"

@property
def models(self):
return "models"

@property
def project_config(self):
return {
"snapshot-paths": ['check-snapshots'],
"test-paths": ['check-snapshots-expected'],
"source-paths": [],
}

def test_snapshot_check_cols_cycle(self):
results = self.run_dbt(["snapshot", '--vars', 'version: 1'])
self.assertEqual(len(results), 1)

results = self.run_dbt(["snapshot", '--vars', 'version: 2'])
self.assertEqual(len(results), 1)

results = self.run_dbt(["snapshot", '--vars', 'version: 3'])
self.assertEqual(len(results), 1)

def assert_expected(self):
self.run_dbt(['test', '--data', '--vars', 'version: 3'])

@use_profile('snowflake')
def test__snowflake__simple_snapshot(self):
self.test_snapshot_check_cols_cycle()
self.assert_expected()

@use_profile('postgres')
def test__postgres__simple_snapshot(self):
self.test_snapshot_check_cols_cycle()
self.assert_expected()

@use_profile('bigquery')
def test__bigquery__simple_snapshot(self):
self.test_snapshot_check_cols_cycle()
self.assert_expected()

@use_profile('redshift')
def test__redshift__simple_snapshot(self):
self.test_snapshot_check_cols_cycle()
self.assert_expected()

0 comments on commit b12484b

Please sign in to comment.