diff --git a/clients/shared/table.go b/clients/shared/table.go index a3bde8c41..3a1fbeac7 100644 --- a/clients/shared/table.go +++ b/clients/shared/table.go @@ -30,21 +30,34 @@ func CreateTable(ctx context.Context, dwh destination.DataWarehouse, tableData * return nil } -func AlterTableAddColumns(ctx context.Context, dwh destination.DataWarehouse, tc *types.DwhTableConfig, tableID sql.TableIdentifier, columns []columns.Column) error { - if len(columns) == 0 { +func AlterTableAddColumns(ctx context.Context, dwh destination.DataWarehouse, tc *types.DwhTableConfig, tableID sql.TableIdentifier, cols []columns.Column) error { + if len(cols) == 0 { return nil } - sqlParts, addedCols := ddl.BuildAlterTableAddColumns(dwh.Dialect(), tableID, columns) + var colsToAdd []columns.Column + for _, col := range cols { + if col.ShouldSkip() { + continue + } + + colsToAdd = append(colsToAdd, col) + } + + sqlParts, err := ddl.BuildAlterTableAddColumns(dwh.Dialect(), tableID, colsToAdd) + if err != nil { + return fmt.Errorf("failed to build alter table add columns: %w", err) + } + for _, sqlPart := range sqlParts { slog.Info("[DDL] Executing query", slog.String("query", sqlPart)) - if _, err := dwh.ExecContext(ctx, sqlPart); err != nil { + if _, err = dwh.ExecContext(ctx, sqlPart); err != nil { if !dwh.Dialect().IsColumnAlreadyExistsErr(err) { return fmt.Errorf("failed to alter table: %w", err) } } } - tc.MutateInMemoryColumns(constants.Add, addedCols...) + tc.MutateInMemoryColumns(constants.Add, colsToAdd...) return nil } diff --git a/lib/destination/ddl/ddl.go b/lib/destination/ddl/ddl.go index caf20a477..8546f4a36 100644 --- a/lib/destination/ddl/ddl.go +++ b/lib/destination/ddl/ddl.go @@ -69,20 +69,18 @@ func DropTemporaryTable(dwh destination.DataWarehouse, tableIdentifier sql.Table return nil } -func BuildAlterTableAddColumns(dialect sql.Dialect, tableID sql.TableIdentifier, cols []columns.Column) ([]string, []columns.Column) { +func BuildAlterTableAddColumns(dialect sql.Dialect, tableID sql.TableIdentifier, cols []columns.Column) ([]string, error) { var parts []string - var addedCols []columns.Column for _, col := range cols { if col.ShouldSkip() { - continue + return nil, fmt.Errorf("received an invalid column %q", col.Name()) } sqlPart := fmt.Sprintf("%s %s", dialect.QuoteIdentifier(col.Name()), dialect.DataTypeForKind(col.KindDetails, col.PrimaryKey())) parts = append(parts, dialect.BuildAlterColumnQuery(tableID, constants.Add, sqlPart)) - addedCols = append(addedCols, col) } - return parts, addedCols + return parts, nil } type AlterTableArgs struct { diff --git a/lib/destination/ddl/ddl_test.go b/lib/destination/ddl/ddl_test.go index 55e16cf67..102dfc4dc 100644 --- a/lib/destination/ddl/ddl_test.go +++ b/lib/destination/ddl/ddl_test.go @@ -79,34 +79,29 @@ func TestBuildCreateTableSQL(t *testing.T) { func TestBuildAlterTableAddColumns(t *testing.T) { { // No columns - sqlParts, addedCols := BuildAlterTableAddColumns(nil, nil, []columns.Column{}) + sqlParts, err := BuildAlterTableAddColumns(nil, nil, []columns.Column{}) + assert.NoError(t, err) assert.Empty(t, sqlParts) - assert.Empty(t, addedCols) } { // One column to add col := columns.NewColumn("dusty", typing.String) - sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col}) + sqlParts, err := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col}) + assert.NoError(t, err) assert.Len(t, sqlParts, 1) assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "dusty" VARCHAR(MAX)`, sqlParts[0]) - - assert.Len(t, addedCols, 1) - assert.Equal(t, col, addedCols[0]) } { - // Two columns, it skips the invalid column + // Two columns, one invalid, it will error. col := columns.NewColumn("dusty", typing.String) - sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), + _, err := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{ col, columns.NewColumn("invalid", typing.Invalid), }, ) - assert.Len(t, sqlParts, 1) - assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "dusty" VARCHAR(MAX)`, sqlParts[0]) - assert.Len(t, addedCols, 1) - assert.Equal(t, col, addedCols[0]) + assert.ErrorContains(t, err, `received an invalid column "invalid"`) } { // Three columns to add @@ -114,16 +109,12 @@ func TestBuildAlterTableAddColumns(t *testing.T) { col2 := columns.NewColumn("doge", typing.String) col3 := columns.NewColumn("age", typing.Integer) - sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col1, col2, col3}) + sqlParts, err := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col1, col2, col3}) + assert.NoError(t, err) assert.Len(t, sqlParts, 3) assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "aussie" VARCHAR(MAX)`, sqlParts[0]) assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "doge" VARCHAR(MAX)`, sqlParts[1]) assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "age" INT8`, sqlParts[2]) - - assert.Len(t, addedCols, 3) - assert.Equal(t, col1, addedCols[0]) - assert.Equal(t, col2, addedCols[1]) - assert.Equal(t, col3, addedCols[2]) } }