diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index 9485b41f1..34943329d 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -68,8 +68,8 @@ func (s *Store) Append(ctx context.Context, tableData *optimization.TableData, u query := fmt.Sprintf(`INSERT INTO %s (%s) SELECT %s FROM %s`, tableID.FullyQualifiedName(), - strings.Join(sql.QuoteColumns(tableData.ReadOnlyInMemoryCols().ValidColumns(), s.Dialect()), ","), - strings.Join(sql.QuoteColumns(tableData.ReadOnlyInMemoryCols().ValidColumns(), s.Dialect()), ","), + strings.Join(sql.QuoteColumns(tableData.GetValidColumns(), s.Dialect()), ","), + strings.Join(sql.QuoteColumns(tableData.GetValidColumns(), s.Dialect()), ","), temporaryTableID.FullyQualifiedName(), ) @@ -133,7 +133,7 @@ func (s *Store) GetClient(ctx context.Context) *bigquery.Client { } func (s *Store) putTable(ctx context.Context, bqTableID dialect.TableIdentifier, tableData *optimization.TableData) error { - columns := tableData.ReadOnlyInMemoryCols().ValidColumns() + columns := tableData.GetValidColumns() messageDescriptor, err := columnsToMessageDescriptor(columns) if err != nil { diff --git a/clients/bigquery/dialect/dialect_test.go b/clients/bigquery/dialect/dialect_test.go index 06dbce1ee..548c3e9bf 100644 --- a/clients/bigquery/dialect/dialect_test.go +++ b/clients/bigquery/dialect/dialect_test.go @@ -159,11 +159,19 @@ func TestBigQueryDialect_BuildMergeQueries_SoftDelete(t *testing.T) { func TestBigQueryDialect_BuildMergeQueries_JSONKey(t *testing.T) { orderOIDCol := columns.NewColumn("order_oid", typing.Struct) - var cols columns.Columns - cols.AddColumn(orderOIDCol) - cols.AddColumn(columns.NewColumn("name", typing.String)) - cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - cols.AddColumn(columns.NewColumn(constants.OnlySetDeleteColumnMarker, typing.Boolean)) + + var cols []columns.Column + cols = append(cols, orderOIDCol) + cols = append(cols, columns.NewColumn("name", typing.String)) + cols = append(cols, columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) + cols = append(cols, columns.NewColumn(constants.OnlySetDeleteColumnMarker, typing.Boolean)) + + var validCols []columns.Column + for _, col := range cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } fakeTableID := &mocks.FakeTableIdentifier{} fakeTableID.FullyQualifiedNameReturns("customers.orders") @@ -173,7 +181,7 @@ func TestBigQueryDialect_BuildMergeQueries_JSONKey(t *testing.T) { "customers.orders_tmp", []columns.Column{orderOIDCol}, nil, - cols.ValidColumns(), + validCols, false, false, ) diff --git a/clients/databricks/dialect/dialect_test.go b/clients/databricks/dialect/dialect_test.go index 59c0c6b4d..b6f70e28f 100644 --- a/clients/databricks/dialect/dialect_test.go +++ b/clients/databricks/dialect/dialect_test.go @@ -127,7 +127,7 @@ func TestDatabricksDialect_BuildDedupeQueries(t *testing.T) { } } -func buildColumns(colTypesMap map[string]typing.KindDetails) *columns.Columns { +func buildColumns(colTypesMap map[string]typing.KindDetails) []columns.Column { var colNames []string for colName := range colTypesMap { colNames = append(colNames, colName) @@ -135,12 +135,12 @@ func buildColumns(colTypesMap map[string]typing.KindDetails) *columns.Columns { // Sort the column names alphabetically to ensure deterministic order slices.Sort(colNames) - var cols columns.Columns + var cols []columns.Column for _, colName := range colNames { - cols.AddColumn(columns.NewColumn(colName, colTypesMap[colName])) + cols = append(cols, columns.NewColumn(colName, colTypesMap[colName])) } - return &cols + return cols } func TestDatabricksDialect_BuildMergeQueries_SoftDelete(t *testing.T) { @@ -156,13 +156,20 @@ func TestDatabricksDialect_BuildMergeQueries_SoftDelete(t *testing.T) { fakeTableID := &mocks.FakeTableIdentifier{} fakeTableID.FullyQualifiedNameReturns(fqTable) + var validCols []columns.Column + for _, col := range _cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } + { statements, err := DatabricksDialect{}.BuildMergeQueries( fakeTableID, fqTable, []columns.Column{columns.NewColumn("id", typing.Invalid)}, nil, - _cols.ValidColumns(), + validCols, true, false, ) @@ -193,12 +200,19 @@ func TestDatabricksDialect_BuildMergeQueries(t *testing.T) { fakeTableID := &mocks.FakeTableIdentifier{} fakeTableID.FullyQualifiedNameReturns(fqTable) + var validCols []columns.Column + for _, col := range _cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } + statements, err := DatabricksDialect{}.BuildMergeQueries( fakeTableID, fqTable, []columns.Column{columns.NewColumn("id", typing.Invalid)}, nil, - _cols.ValidColumns(), + validCols, false, false, ) @@ -225,6 +239,13 @@ func TestDatabricksDialect_BuildMergeQueries_CompositeKey(t *testing.T) { fakeTableID := &mocks.FakeTableIdentifier{} fakeTableID.FullyQualifiedNameReturns(fqTable) + var validCols []columns.Column + for _, col := range _cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } + statements, err := DatabricksDialect{}.BuildMergeQueries( fakeTableID, fqTable, @@ -233,7 +254,7 @@ func TestDatabricksDialect_BuildMergeQueries_CompositeKey(t *testing.T) { columns.NewColumn("another_id", typing.Invalid), }, nil, - _cols.ValidColumns(), + validCols, false, false, ) @@ -262,6 +283,13 @@ func TestDatabricksDialect_BuildMergeQueries_EscapePrimaryKeys(t *testing.T) { fakeTableID := &mocks.FakeTableIdentifier{} fakeTableID.FullyQualifiedNameReturns(fqTable) + var validCols []columns.Column + for _, col := range _cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } + statements, err := DatabricksDialect{}.BuildMergeQueries( fakeTableID, fqTable, @@ -270,7 +298,7 @@ func TestDatabricksDialect_BuildMergeQueries_EscapePrimaryKeys(t *testing.T) { columns.NewColumn("group", typing.Invalid), }, nil, - _cols.ValidColumns(), + validCols, false, false, ) diff --git a/clients/databricks/store.go b/clients/databricks/store.go index 74d86b2c7..dc8795917 100644 --- a/clients/databricks/store.go +++ b/clients/databricks/store.go @@ -154,7 +154,7 @@ func (s Store) writeTemporaryTableFile(tableData *optimization.TableData, newTab writer := csv.NewWriter(file) writer.Comma = '\t' - columns := tableData.ReadOnlyInMemoryCols().ValidColumns() + columns := tableData.GetValidColumns() for _, value := range tableData.Rows() { var row []string for _, col := range columns { diff --git a/clients/mssql/staging.go b/clients/mssql/staging.go index 1a1272c07..236789732 100644 --- a/clients/mssql/staging.go +++ b/clients/mssql/staging.go @@ -32,7 +32,7 @@ func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimizati } }() - cols := tableData.ReadOnlyInMemoryCols().ValidColumns() + cols := tableData.GetValidColumns() stmt, err := tx.Prepare(mssql.CopyIn(tempTableID.FullyQualifiedName(), mssql.BulkOptions{}, columns.ColumnNames(cols)...)) if err != nil { return fmt.Errorf("failed to prepare bulk insert: %w", err) diff --git a/clients/redshift/dialect/dialect_test.go b/clients/redshift/dialect/dialect_test.go index 1e7bb8992..306686c2c 100644 --- a/clients/redshift/dialect/dialect_test.go +++ b/clients/redshift/dialect/dialect_test.go @@ -168,15 +168,15 @@ func getBasicColumnsForTest(compositeKey bool) result { textToastCol := columns.NewColumn("toast_text", typing.String) textToastCol.ToastColumn = true - var cols columns.Columns - cols.AddColumn(idCol) - cols.AddColumn(emailCol) - cols.AddColumn(columns.NewColumn("first_name", typing.String)) - cols.AddColumn(columns.NewColumn("last_name", typing.String)) - cols.AddColumn(columns.NewColumn("created_at", typing.TimestampNTZ)) - cols.AddColumn(textToastCol) - cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - cols.AddColumn(columns.NewColumn(constants.OnlySetDeleteColumnMarker, typing.Boolean)) + var cols []columns.Column + cols = append(cols, idCol) + cols = append(cols, emailCol) + cols = append(cols, columns.NewColumn("first_name", typing.String)) + cols = append(cols, columns.NewColumn("last_name", typing.String)) + cols = append(cols, columns.NewColumn("created_at", typing.TimestampNTZ)) + cols = append(cols, textToastCol) + cols = append(cols, columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) + cols = append(cols, columns.NewColumn(constants.OnlySetDeleteColumnMarker, typing.Boolean)) var pks []columns.Column pks = append(pks, idCol) @@ -185,9 +185,16 @@ func getBasicColumnsForTest(compositeKey bool) result { pks = append(pks, emailCol) } + var validCols []columns.Column + for _, col := range cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } + return result{ PrimaryKeys: pks, - Columns: cols.ValidColumns(), + Columns: validCols, } } diff --git a/clients/redshift/staging.go b/clients/redshift/staging.go index 6bc298ad8..3d1b37cf8 100644 --- a/clients/redshift/staging.go +++ b/clients/redshift/staging.go @@ -69,7 +69,7 @@ func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimizati copyStmt := fmt.Sprintf( `COPY %s (%s) FROM '%s' DELIMITER '\t' NULL AS '\\N' GZIP FORMAT CSV %s dateformat 'auto' timeformat 'auto';`, tempTableID.FullyQualifiedName(), - strings.Join(sql.QuoteColumns(tableData.ReadOnlyInMemoryCols().ValidColumns(), s.Dialect()), ","), + strings.Join(sql.QuoteColumns(tableData.GetValidColumns(), s.Dialect()), ","), s3Uri, s.credentialsClause, ) @@ -94,7 +94,7 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableID writer := csv.NewWriter(gzipWriter) writer.Comma = '\t' - _columns := tableData.ReadOnlyInMemoryCols().ValidColumns() + _columns := tableData.GetValidColumns() columnToNewLengthMap := make(map[string]int32) for _, value := range tableData.Rows() { var row []string diff --git a/clients/s3/s3.go b/clients/s3/s3.go index f0b10be9e..cc50b1e60 100644 --- a/clients/s3/s3.go +++ b/clients/s3/s3.go @@ -77,7 +77,7 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er return nil } - cols := tableData.ReadOnlyInMemoryCols().ValidColumns() + cols := tableData.GetValidColumns() schema, err := parquetutil.BuildCSVSchema(cols) if err != nil { return fmt.Errorf("failed to generate parquet schema: %w", err) diff --git a/clients/shared/merge.go b/clients/shared/merge.go index 459d8faa1..8f9392d19 100644 --- a/clients/shared/merge.go +++ b/clients/shared/merge.go @@ -128,15 +128,10 @@ func Merge(ctx context.Context, dwh destination.DataWarehouse, tableData *optimi return fmt.Errorf("primary keys cannot be empty") } - validColumns := cols.ValidColumns() + validColumns := tableData.GetValidColumns() if len(validColumns) == 0 { return fmt.Errorf("columns cannot be empty") } - for _, column := range validColumns { - if !column.IsValid() { - return fmt.Errorf("column %q is invalid and should be skipped", column.Name()) - } - } mergeStatements, err := dwh.Dialect().BuildMergeQueries( tableID, diff --git a/clients/snowflake/dialect/dialect_test.go b/clients/snowflake/dialect/dialect_test.go index f06e07e85..3452a4d73 100644 --- a/clients/snowflake/dialect/dialect_test.go +++ b/clients/snowflake/dialect/dialect_test.go @@ -77,20 +77,20 @@ func TestSnowflakeDialect_BuildIsNotToastValueExpression(t *testing.T) { ) } -func buildColumns(colTypesMap map[string]typing.KindDetails) *columns.Columns { - colNames := []string{} +func buildColumns(colTypesMap map[string]typing.KindDetails) []columns.Column { + var colNames []string for colName := range colTypesMap { colNames = append(colNames, colName) } // Sort the column names alphabetically to ensure deterministic order slices.Sort(colNames) - var cols columns.Columns + var cols []columns.Column for _, colName := range colNames { - cols.AddColumn(columns.NewColumn(colName, colTypesMap[colName])) + cols = append(cols, columns.NewColumn(colName, colTypesMap[colName])) } - return &cols + return cols } func TestSnowflakeDialect_BuildMergeQueries_SoftDelete(t *testing.T) { @@ -106,13 +106,20 @@ func TestSnowflakeDialect_BuildMergeQueries_SoftDelete(t *testing.T) { fakeTableID := &mocks.FakeTableIdentifier{} fakeTableID.FullyQualifiedNameReturns(fqTable) + var validCols []columns.Column + for _, col := range _cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } + { statements, err := SnowflakeDialect{}.BuildMergeQueries( fakeTableID, fqTable, []columns.Column{columns.NewColumn("id", typing.Invalid)}, nil, - _cols.ValidColumns(), + validCols, true, false, ) @@ -140,12 +147,19 @@ func TestSnowflakeDialect_BuildMergeQueries(t *testing.T) { fakeTableID := &mocks.FakeTableIdentifier{} fakeTableID.FullyQualifiedNameReturns(fqTable) + var validCols []columns.Column + for _, col := range _cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } + statements, err := SnowflakeDialect{}.BuildMergeQueries( fakeTableID, fqTable, []columns.Column{columns.NewColumn("id", typing.Invalid)}, nil, - _cols.ValidColumns(), + validCols, false, false, ) @@ -170,6 +184,13 @@ func TestSnowflakeDialect_BuildMergeQueries_CompositeKey(t *testing.T) { fakeTableID := &mocks.FakeTableIdentifier{} fakeTableID.FullyQualifiedNameReturns(fqTable) + var validCols []columns.Column + for _, col := range _cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } + statements, err := SnowflakeDialect{}.BuildMergeQueries( fakeTableID, fqTable, @@ -178,7 +199,7 @@ func TestSnowflakeDialect_BuildMergeQueries_CompositeKey(t *testing.T) { columns.NewColumn("another_id", typing.Invalid), }, nil, - _cols.ValidColumns(), + validCols, false, false, ) @@ -205,6 +226,13 @@ func TestSnowflakeDialect_BuildMergeQueries_EscapePrimaryKeys(t *testing.T) { fakeTableID := &mocks.FakeTableIdentifier{} fakeTableID.FullyQualifiedNameReturns(fqTable) + var validCols []columns.Column + for _, col := range _cols { + if col.IsValid() { + validCols = append(validCols, col) + } + } + statements, err := SnowflakeDialect{}.BuildMergeQueries( fakeTableID, fqTable, @@ -213,7 +241,7 @@ func TestSnowflakeDialect_BuildMergeQueries_EscapePrimaryKeys(t *testing.T) { columns.NewColumn("group", typing.Invalid), }, nil, - _cols.ValidColumns(), + validCols, false, false, ) diff --git a/clients/snowflake/staging.go b/clients/snowflake/staging.go index c1ee740ab..3baa63d6d 100644 --- a/clients/snowflake/staging.go +++ b/clients/snowflake/staging.go @@ -78,7 +78,7 @@ func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimizati // COPY the CSV file (in Snowflake) into a table copyCommand := fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)", tempTableID.FullyQualifiedName(), - strings.Join(sql.QuoteColumns(tableData.ReadOnlyInMemoryCols().ValidColumns(), s.Dialect()), ","), + strings.Join(sql.QuoteColumns(tableData.GetValidColumns(), s.Dialect()), ","), escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableID, "%")) if additionalSettings.AdditionalCopyClause != "" { @@ -103,7 +103,7 @@ func (s *Store) writeTemporaryTableFile(tableData *optimization.TableData, newTa writer := csv.NewWriter(file) writer.Comma = '\t' - columns := tableData.ReadOnlyInMemoryCols().ValidColumns() + columns := tableData.GetValidColumns() for _, row := range tableData.Rows() { var csvRow []string for _, col := range columns { diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index 943995fe7..cfe7ca56a 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -204,26 +204,6 @@ func (c *Columns) GetColumn(name string) (Column, bool) { return Column{}, false } -// ValidColumns will filter all the `Invalid` columns so that we do not update them. -// This is used mostly for the SQL MERGE queries. -func (c *Columns) ValidColumns() []Column { - if c == nil { - return []Column{} - } - - c.RLock() - defer c.RUnlock() - - var cols []Column - for _, col := range c.columns { - if col.IsValid() { - cols = append(cols, col) - } - } - - return cols -} - func (c *Columns) GetColumns() []Column { if c == nil { return []Column{} diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index 611791084..7c1b5823a 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -1,8 +1,6 @@ package columns import ( - "fmt" - "slices" "testing" "github.com/artie-labs/transfer/lib/config/constants" @@ -104,52 +102,6 @@ func TestColumn_ShouldBackfill(t *testing.T) { } } -func TestColumns_ValidColumns(t *testing.T) { - var happyPathCols = []Column{ - { - name: "hi", - KindDetails: typing.String, - }, - { - name: "bye", - KindDetails: typing.String, - }, - { - name: "start", - KindDetails: typing.String, - }, - } - - extraCols := happyPathCols - for i := 0; i < 100; i++ { - extraCols = append(extraCols, Column{ - name: fmt.Sprintf("hello_%v", i), - KindDetails: typing.Invalid, - }) - } - - testCases := []struct { - name string - cols []Column - expectedCols []Column - }{ - { - name: "happy path", - cols: happyPathCols, - expectedCols: slices.Clone(happyPathCols), - }, - { - name: "happy path + extra col", - cols: extraCols, - expectedCols: slices.Clone(happyPathCols), - }, - } - - for _, testCase := range testCases { - assert.Equal(t, testCase.expectedCols, (&Columns{columns: testCase.cols}).ValidColumns(), testCase.name) - } -} - func TestColumns_UpsertColumns(t *testing.T) { keys := []string{"a", "b", "c", "d", "e"} var cols Columns