diff --git a/lib/destination/ddl/ddl.go b/lib/destination/ddl/ddl.go
index 5bcb48c5e..ea1681933 100644
--- a/lib/destination/ddl/ddl.go
+++ b/lib/destination/ddl/ddl.go
@@ -76,6 +76,10 @@ func (a AlterTableArgs) Validate() error {
 	return nil
 }
 
+func shouldCreatePrimaryKey(col columns.Column, mode config.Mode, createTable bool) bool {
+	return col.PrimaryKey() && mode == config.Replication && createTable
+}
+
 func (a AlterTableArgs) buildStatements(cols ...columns.Column) ([]string, []columns.Column) {
 	var mutateCol []columns.Column
 	// It's okay to combine since args.ColumnOp only takes one of: `Delete` or `Add`
@@ -97,13 +101,11 @@
 		switch a.ColumnOp {
 		case constants.Add:
 			colName := a.Dialect.QuoteIdentifier(col.Name())
-
-			if col.PrimaryKey() && a.Mode != config.History {
-				// Don't create a PK for history mode because it's append-only, so the primary key should not be enforced.
+			if shouldCreatePrimaryKey(col, a.Mode, a.CreateTable) {
 				pkCols = append(pkCols, colName)
 			}
 
-			colSQLParts = append(colSQLParts, fmt.Sprintf(`%s %s`, colName, a.Dialect.DataTypeForKind(col.KindDetails, col.PrimaryKey())))
+			colSQLParts = append(colSQLParts, fmt.Sprintf("%s %s", colName, a.Dialect.DataTypeForKind(col.KindDetails, col.PrimaryKey())))
 		case constants.Delete:
 			colSQLParts = append(colSQLParts, a.Dialect.QuoteIdentifier(col.Name()))
 		}
@@ -140,7 +142,6 @@ func (a AlterTableArgs) AlterTable(dwh destination.DataWarehouse, cols ...column
 	}
 
 	alterStatements, mutateCol := a.buildStatements(cols...)
-
 	for _, sqlQuery := range alterStatements {
 		slog.Info("DDL - executing sql", slog.String("query", sqlQuery))
 		if _, err := dwh.Exec(sqlQuery); err != nil {
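The new helper collapses the primary-key decision into a single predicate: a PRIMARY KEY clause is only emitted for a primary-key column, in replication mode, during a CREATE TABLE operation. Previously the check only excluded history mode, so an ALTER TABLE add-column path could also try to register a primary key. A standalone sketch of the truth table, using plain `bool`/`mode` stand-ins for the repository's `columns.Column` and `config.Mode`:

```go
package main

import "fmt"

type mode int

const (
	replication mode = iota
	history
)

// Mirrors shouldCreatePrimaryKey from lib/destination/ddl for illustration.
func shouldCreatePrimaryKey(isPrimaryKey bool, m mode, createTable bool) bool {
	return isPrimaryKey && m == replication && createTable
}

func main() {
	fmt.Println(shouldCreatePrimaryKey(true, replication, true))  // true: PK column + replication + CREATE TABLE
	fmt.Println(shouldCreatePrimaryKey(true, history, true))      // false: history mode is append-only
	fmt.Println(shouldCreatePrimaryKey(true, replication, false)) // false: not a CREATE TABLE operation
	fmt.Println(shouldCreatePrimaryKey(false, replication, true)) // false: not a primary-key column
}
```
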
diff --git a/lib/destination/ddl/ddl_temp_test.go b/lib/destination/ddl/ddl_temp_test.go
index a314b2c07..2c142a8eb 100644
--- a/lib/destination/ddl/ddl_temp_test.go
+++ b/lib/destination/ddl/ddl_temp_test.go
@@ -1,6 +1,7 @@
 package ddl_test
 
 import (
+	"fmt"
 	"time"
 
 	"github.com/stretchr/testify/assert"
@@ -9,8 +10,11 @@ import (
 	"github.com/artie-labs/transfer/clients/snowflake/dialect"
 	"github.com/artie-labs/transfer/lib/config"
 	"github.com/artie-labs/transfer/lib/config/constants"
+	"github.com/artie-labs/transfer/lib/destination"
 	"github.com/artie-labs/transfer/lib/destination/ddl"
 	"github.com/artie-labs/transfer/lib/destination/types"
+	"github.com/artie-labs/transfer/lib/kafkalib"
+	"github.com/artie-labs/transfer/lib/mocks"
 	"github.com/artie-labs/transfer/lib/typing"
 	"github.com/artie-labs/transfer/lib/typing/columns"
 )
@@ -112,3 +116,112 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() {
 		assert.Contains(d.T(), bqQuery, "CREATE TABLE IF NOT EXISTS `db`.`schema`.`tempTableName` (`foo` string,`bar` float64,`select` string) OPTIONS (expiration_timestamp =")
 	}
 }
+
+func (d *DDLTestSuite) Test_DropTemporaryTableCaseSensitive() {
+	tablesToDrop := []string{
+		"foo",
+		"abcdef",
+		"gghh",
+	}
+
+	for i, dest := range []destination.DataWarehouse{d.bigQueryStore, d.snowflakeStagesStore} {
+		var fakeStore *mocks.FakeStore
+		if i == 0 {
+			fakeStore = d.fakeBigQueryStore
+		} else {
+			fakeStore = d.fakeSnowflakeStagesStore
+		}
+
+		for tableIndex, table := range tablesToDrop {
+			tableIdentifier := dest.IdentifierFor(kafkalib.TopicConfig{}, fmt.Sprintf("%s_%s", table, constants.ArtiePrefix))
+			_ = ddl.DropTemporaryTable(dest, tableIdentifier, false)
+
+			// There should be the same number of DROP TABLE calls as tables to drop.
+			assert.Equal(d.T(), tableIndex+1, fakeStore.ExecCallCount())
+			query, _ := fakeStore.ExecArgsForCall(tableIndex)
+			assert.Equal(d.T(), fmt.Sprintf("DROP TABLE IF EXISTS %s", tableIdentifier.FullyQualifiedName()), query)
+		}
+	}
+}
+
+func (d *DDLTestSuite) Test_DropTemporaryTable() {
+	doNotDropTables := []string{
+		"foo",
+		"bar",
+		"abcd",
+		"customers.customers",
+	}
+
+	// Should not drop since these names do not contain the Artie prefix.
+	for _, table := range doNotDropTables {
+		tableID := d.bigQueryStore.IdentifierFor(kafkalib.TopicConfig{}, table)
+		_ = ddl.DropTemporaryTable(d.snowflakeStagesStore, tableID, false)
+		assert.Equal(d.T(), 0, d.fakeSnowflakeStagesStore.ExecCallCount())
+	}
+
+	for i, _dwh := range []destination.DataWarehouse{d.bigQueryStore, d.snowflakeStagesStore} {
+		var fakeStore *mocks.FakeStore
+		if i == 0 {
+			fakeStore = d.fakeBigQueryStore
+		} else {
+			fakeStore = d.fakeSnowflakeStagesStore
+		}
+
+		for _, doNotDropTable := range doNotDropTables {
+			doNotDropTableID := d.bigQueryStore.IdentifierFor(kafkalib.TopicConfig{}, doNotDropTable)
+			_ = ddl.DropTemporaryTable(_dwh, doNotDropTableID, false)
+
+			assert.Equal(d.T(), 0, fakeStore.ExecCallCount())
+		}
+
+		for index, table := range doNotDropTables {
+			fullTableID := d.bigQueryStore.IdentifierFor(kafkalib.TopicConfig{}, fmt.Sprintf("%s_%s", table, constants.ArtiePrefix))
+			_ = ddl.DropTemporaryTable(_dwh, fullTableID, false)
+
+			count := index + 1
+			assert.Equal(d.T(), count, fakeStore.ExecCallCount())
+
+			query, _ := fakeStore.ExecArgsForCall(index)
+			assert.Equal(d.T(), fmt.Sprintf("DROP TABLE IF EXISTS %s", fullTableID.FullyQualifiedName()), query)
+		}
+	}
+}
+
+func (d *DDLTestSuite) Test_DropTemporaryTable_Errors() {
+	tablesToDrop := []string{
+		"foo",
+		"bar",
+		"abcd",
+		"customers.customers",
+	}
+
+	randomErr := fmt.Errorf("random err")
+	for i, _dwh := range []destination.DataWarehouse{d.bigQueryStore, d.snowflakeStagesStore} {
+		var fakeStore *mocks.FakeStore
+		if i == 0 {
+			fakeStore = d.fakeBigQueryStore
+			d.fakeBigQueryStore.ExecReturns(nil, randomErr)
+		} else {
+			fakeStore = d.fakeSnowflakeStagesStore
+			d.fakeSnowflakeStagesStore.ExecReturns(nil, randomErr)
+		}
+
+		var count int
+		for _, shouldReturnErr := range []bool{true, false} {
+			for _, table := range tablesToDrop {
+				tableID := d.bigQueryStore.IdentifierFor(kafkalib.TopicConfig{}, fmt.Sprintf("%s_%s", table, constants.ArtiePrefix))
+				err := ddl.DropTemporaryTable(_dwh, tableID, shouldReturnErr)
+				if shouldReturnErr {
+					assert.ErrorContains(d.T(), err, randomErr.Error())
+				} else {
+					assert.NoError(d.T(), err)
+				}
+
+				count += 1
+				assert.Equal(d.T(), count, fakeStore.ExecCallCount())
+				query, _ := fakeStore.ExecArgsForCall(count - 1)
+				assert.Equal(d.T(), fmt.Sprintf("DROP TABLE IF EXISTS %s", tableID.FullyQualifiedName()), query)
+			}
+		}
+	}
+}
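These tests pin down the contract of `ddl.DropTemporaryTable` from the outside: names lacking the Artie prefix are never dropped, each temporary table gets exactly one `DROP TABLE IF EXISTS`, and `Exec` errors surface only when requested. A self-contained sketch of that contract; the interface names and the `__artie` prefix value are assumptions for illustration, not the repository's actual implementation:

```go
package sketch

import (
	"fmt"
	"strings"
)

const artiePrefix = "__artie" // assumed value of constants.ArtiePrefix

type tableIdentifier interface {
	Table() string
	FullyQualifiedName() string
}

type executor interface {
	Exec(query string, args ...any) (any, error)
}

func dropTemporaryTable(dwh executor, tableID tableIdentifier, shouldReturnErr bool) error {
	// Names without the Artie prefix are not temporary tables; leave them
	// alone. This is what the zero ExecCallCount assertions verify.
	if !strings.Contains(strings.ToLower(tableID.Table()), artiePrefix) {
		return nil
	}

	// Exactly one DROP per temporary table; errors propagate only on request.
	if _, err := dwh.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableID.FullyQualifiedName())); err != nil && shouldReturnErr {
		return err
	}

	return nil
}
```
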
"github.com/artie-labs/transfer/lib/config" + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/columns" +) - for i, dest := range []destination.DataWarehouse{d.bigQueryStore, d.snowflakeStagesStore} { - var fakeStore *mocks.FakeStore - if i == 0 { - fakeStore = d.fakeBigQueryStore - } else { - fakeStore = d.fakeSnowflakeStagesStore +func TestShouldCreatePrimaryKey(t *testing.T) { + pk := columns.NewColumn("foo", typing.String) + pk.SetPrimaryKeyForTest(true) + { + // Primary key check + { + // Column is not a primary key + col := columns.NewColumn("foo", typing.String) + assert.False(t, shouldCreatePrimaryKey(col, config.Replication, true)) } - - for tableIndex, table := range tablesToDrop { - tableIdentifier := dest.IdentifierFor(kafkalib.TopicConfig{}, fmt.Sprintf("%s_%s", table, constants.ArtiePrefix)) - _ = ddl.DropTemporaryTable(dest, tableIdentifier, false) - - // There should be the same number of DROP table calls as the number of tables to drop. - assert.Equal(d.T(), tableIndex+1, fakeStore.ExecCallCount()) - query, _ := fakeStore.ExecArgsForCall(tableIndex) - assert.Equal(d.T(), fmt.Sprintf("DROP TABLE IF EXISTS %s", tableIdentifier.FullyQualifiedName()), query) + { + // Column is a primary key + assert.True(t, shouldCreatePrimaryKey(pk, config.Replication, true)) } } -} - -func (d *DDLTestSuite) Test_DropTemporaryTable() { - doNotDropTables := []string{ - "foo", - "bar", - "abcd", - "customers.customers", - } - - // Should not drop since these do not have Artie prefix in the name. - for _, table := range doNotDropTables { - tableID := d.bigQueryStore.IdentifierFor(kafkalib.TopicConfig{}, table) - _ = ddl.DropTemporaryTable(d.snowflakeStagesStore, tableID, false) - assert.Equal(d.T(), 0, d.fakeSnowflakeStagesStore.ExecCallCount()) - } - - for i, _dwh := range []destination.DataWarehouse{d.bigQueryStore, d.snowflakeStagesStore} { - var fakeStore *mocks.FakeStore - if i == 0 { - fakeStore = d.fakeBigQueryStore - } else { - fakeStore = d.fakeSnowflakeStagesStore - - } - - for _, doNotDropTable := range doNotDropTables { - doNotDropTableID := d.bigQueryStore.IdentifierFor(kafkalib.TopicConfig{}, doNotDropTable) - _ = ddl.DropTemporaryTable(_dwh, doNotDropTableID, false) - - assert.Equal(d.T(), 0, fakeStore.ExecCallCount()) - } - - for index, table := range doNotDropTables { - fullTableID := d.bigQueryStore.IdentifierFor(kafkalib.TopicConfig{}, fmt.Sprintf("%s_%s", table, constants.ArtiePrefix)) - _ = ddl.DropTemporaryTable(_dwh, fullTableID, false) - - count := index + 1 - assert.Equal(d.T(), count, fakeStore.ExecCallCount()) - - query, _ := fakeStore.ExecArgsForCall(index) - assert.Equal(d.T(), fmt.Sprintf("DROP TABLE IF EXISTS %s", fullTableID.FullyQualifiedName()), query) - } + { + // False because it's history mode + // It should be false because we are appending rows to this table. 
diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go
index d9d5fad78..bb9cb1941 100644
--- a/lib/typing/columns/columns.go
+++ b/lib/typing/columns/columns.go
@@ -64,7 +64,6 @@ func (c *Column) SetBackfilled(backfilled bool) {
 func (c *Column) Backfilled() bool {
 	return c.backfilled
 }
-
 func (c *Column) SetDefaultValue(value any) {
 	c.defaultValue = value
 }
@@ -73,6 +72,10 @@ func (c *Column) ToLowerName() {
 	c.name = strings.ToLower(c.name)
 }
 
+func (c *Column) SetPrimaryKeyForTest(primaryKey bool) {
+	c.primaryKey = primaryKey
+}
+
 func (c *Column) ShouldBackfill() bool {
 	if c.primaryKey {
 		// Never need to backfill primary key.
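SetPrimaryKeyForTest exists so tests outside the columns package can flip the unexported primaryKey field without plumbing it through a constructor; the ForTest suffix signals that production code should stick to the existing read path, PrimaryKey(). A minimal hypothetical companion test, assuming the same imports as the tests above:

```go
func TestSetPrimaryKeyForTest(t *testing.T) {
	// NewColumn leaves primaryKey at its zero value (false).
	col := columns.NewColumn("id", typing.String)
	assert.False(t, col.PrimaryKey())

	// The test-only setter flips the unexported field.
	col.SetPrimaryKeyForTest(true)
	assert.True(t, col.PrimaryKey())
}
```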