Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed Nov 26, 2024
1 parent 1a09c91 commit 87e6d9a
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 117 deletions.
6 changes: 3 additions & 3 deletions clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func (s *Store) Append(ctx context.Context, tableData *optimization.TableData, u

query := fmt.Sprintf(`INSERT INTO %s (%s) SELECT %s FROM %s`,
tableID.FullyQualifiedName(),
strings.Join(sql.QuoteColumns(tableData.ReadOnlyInMemoryCols().ValidColumns(), s.Dialect()), ","),
strings.Join(sql.QuoteColumns(tableData.ReadOnlyInMemoryCols().ValidColumns(), s.Dialect()), ","),
strings.Join(sql.QuoteColumns(tableData.GetValidColumns(), s.Dialect()), ","),
strings.Join(sql.QuoteColumns(tableData.GetValidColumns(), s.Dialect()), ","),
temporaryTableID.FullyQualifiedName(),
)

Expand Down Expand Up @@ -133,7 +133,7 @@ func (s *Store) GetClient(ctx context.Context) *bigquery.Client {
}

func (s *Store) putTable(ctx context.Context, bqTableID dialect.TableIdentifier, tableData *optimization.TableData) error {
columns := tableData.ReadOnlyInMemoryCols().ValidColumns()
columns := tableData.GetValidColumns()

messageDescriptor, err := columnsToMessageDescriptor(columns)
if err != nil {
Expand Down
20 changes: 14 additions & 6 deletions clients/bigquery/dialect/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,19 @@ func TestBigQueryDialect_BuildMergeQueries_SoftDelete(t *testing.T) {

func TestBigQueryDialect_BuildMergeQueries_JSONKey(t *testing.T) {
orderOIDCol := columns.NewColumn("order_oid", typing.Struct)
var cols columns.Columns
cols.AddColumn(orderOIDCol)
cols.AddColumn(columns.NewColumn("name", typing.String))
cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))
cols.AddColumn(columns.NewColumn(constants.OnlySetDeleteColumnMarker, typing.Boolean))

var cols []columns.Column
cols = append(cols, orderOIDCol)
cols = append(cols, columns.NewColumn("name", typing.String))
cols = append(cols, columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))
cols = append(cols, columns.NewColumn(constants.OnlySetDeleteColumnMarker, typing.Boolean))

var validCols []columns.Column
for _, col := range cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns("customers.orders")
Expand All @@ -173,7 +181,7 @@ func TestBigQueryDialect_BuildMergeQueries_JSONKey(t *testing.T) {
"customers.orders_tmp",
[]columns.Column{orderOIDCol},
nil,
cols.ValidColumns(),
validCols,
false,
false,
)
Expand Down
44 changes: 36 additions & 8 deletions clients/databricks/dialect/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,20 @@ func TestDatabricksDialect_BuildDedupeQueries(t *testing.T) {
}
}

func buildColumns(colTypesMap map[string]typing.KindDetails) *columns.Columns {
func buildColumns(colTypesMap map[string]typing.KindDetails) []columns.Column {
var colNames []string
for colName := range colTypesMap {
colNames = append(colNames, colName)
}
// Sort the column names alphabetically to ensure deterministic order
slices.Sort(colNames)

var cols columns.Columns
var cols []columns.Column
for _, colName := range colNames {
cols.AddColumn(columns.NewColumn(colName, colTypesMap[colName]))
cols = append(cols, columns.NewColumn(colName, colTypesMap[colName]))
}

return &cols
return cols
}

func TestDatabricksDialect_BuildMergeQueries_SoftDelete(t *testing.T) {
Expand All @@ -156,13 +156,20 @@ func TestDatabricksDialect_BuildMergeQueries_SoftDelete(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns(fqTable)

var validCols []columns.Column
for _, col := range _cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

{
statements, err := DatabricksDialect{}.BuildMergeQueries(
fakeTableID,
fqTable,
[]columns.Column{columns.NewColumn("id", typing.Invalid)},
nil,
_cols.ValidColumns(),
validCols,
true,
false,
)
Expand Down Expand Up @@ -193,12 +200,19 @@ func TestDatabricksDialect_BuildMergeQueries(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns(fqTable)

var validCols []columns.Column
for _, col := range _cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

statements, err := DatabricksDialect{}.BuildMergeQueries(
fakeTableID,
fqTable,
[]columns.Column{columns.NewColumn("id", typing.Invalid)},
nil,
_cols.ValidColumns(),
validCols,
false,
false,
)
Expand All @@ -225,6 +239,13 @@ func TestDatabricksDialect_BuildMergeQueries_CompositeKey(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns(fqTable)

var validCols []columns.Column
for _, col := range _cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

statements, err := DatabricksDialect{}.BuildMergeQueries(
fakeTableID,
fqTable,
Expand All @@ -233,7 +254,7 @@ func TestDatabricksDialect_BuildMergeQueries_CompositeKey(t *testing.T) {
columns.NewColumn("another_id", typing.Invalid),
},
nil,
_cols.ValidColumns(),
validCols,
false,
false,
)
Expand Down Expand Up @@ -262,6 +283,13 @@ func TestDatabricksDialect_BuildMergeQueries_EscapePrimaryKeys(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns(fqTable)

var validCols []columns.Column
for _, col := range _cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

statements, err := DatabricksDialect{}.BuildMergeQueries(
fakeTableID,
fqTable,
Expand All @@ -270,7 +298,7 @@ func TestDatabricksDialect_BuildMergeQueries_EscapePrimaryKeys(t *testing.T) {
columns.NewColumn("group", typing.Invalid),
},
nil,
_cols.ValidColumns(),
validCols,
false,
false,
)
Expand Down
2 changes: 1 addition & 1 deletion clients/databricks/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (s Store) writeTemporaryTableFile(tableData *optimization.TableData, newTab
writer := csv.NewWriter(file)
writer.Comma = '\t'

columns := tableData.ReadOnlyInMemoryCols().ValidColumns()
columns := tableData.GetValidColumns()
for _, value := range tableData.Rows() {
var row []string
for _, col := range columns {
Expand Down
2 changes: 1 addition & 1 deletion clients/mssql/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimizati
}
}()

cols := tableData.ReadOnlyInMemoryCols().ValidColumns()
cols := tableData.GetValidColumns()
stmt, err := tx.Prepare(mssql.CopyIn(tempTableID.FullyQualifiedName(), mssql.BulkOptions{}, columns.ColumnNames(cols)...))
if err != nil {
return fmt.Errorf("failed to prepare bulk insert: %w", err)
Expand Down
27 changes: 17 additions & 10 deletions clients/redshift/dialect/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,15 @@ func getBasicColumnsForTest(compositeKey bool) result {
textToastCol := columns.NewColumn("toast_text", typing.String)
textToastCol.ToastColumn = true

var cols columns.Columns
cols.AddColumn(idCol)
cols.AddColumn(emailCol)
cols.AddColumn(columns.NewColumn("first_name", typing.String))
cols.AddColumn(columns.NewColumn("last_name", typing.String))
cols.AddColumn(columns.NewColumn("created_at", typing.TimestampNTZ))
cols.AddColumn(textToastCol)
cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))
cols.AddColumn(columns.NewColumn(constants.OnlySetDeleteColumnMarker, typing.Boolean))
var cols []columns.Column
cols = append(cols, idCol)
cols = append(cols, emailCol)
cols = append(cols, columns.NewColumn("first_name", typing.String))
cols = append(cols, columns.NewColumn("last_name", typing.String))
cols = append(cols, columns.NewColumn("created_at", typing.TimestampNTZ))
cols = append(cols, textToastCol)
cols = append(cols, columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))
cols = append(cols, columns.NewColumn(constants.OnlySetDeleteColumnMarker, typing.Boolean))

var pks []columns.Column
pks = append(pks, idCol)
Expand All @@ -185,9 +185,16 @@ func getBasicColumnsForTest(compositeKey bool) result {
pks = append(pks, emailCol)
}

var validCols []columns.Column
for _, col := range cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

return result{
PrimaryKeys: pks,
Columns: cols.ValidColumns(),
Columns: validCols,
}
}

Expand Down
4 changes: 2 additions & 2 deletions clients/redshift/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimizati
copyStmt := fmt.Sprintf(
`COPY %s (%s) FROM '%s' DELIMITER '\t' NULL AS '\\N' GZIP FORMAT CSV %s dateformat 'auto' timeformat 'auto';`,
tempTableID.FullyQualifiedName(),
strings.Join(sql.QuoteColumns(tableData.ReadOnlyInMemoryCols().ValidColumns(), s.Dialect()), ","),
strings.Join(sql.QuoteColumns(tableData.GetValidColumns(), s.Dialect()), ","),
s3Uri,
s.credentialsClause,
)
Expand All @@ -94,7 +94,7 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableID

writer := csv.NewWriter(gzipWriter)
writer.Comma = '\t'
_columns := tableData.ReadOnlyInMemoryCols().ValidColumns()
_columns := tableData.GetValidColumns()
columnToNewLengthMap := make(map[string]int32)
for _, value := range tableData.Rows() {
var row []string
Expand Down
2 changes: 1 addition & 1 deletion clients/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er
return nil
}

cols := tableData.ReadOnlyInMemoryCols().ValidColumns()
cols := tableData.GetValidColumns()
schema, err := parquetutil.BuildCSVSchema(cols)
if err != nil {
return fmt.Errorf("failed to generate parquet schema: %w", err)
Expand Down
7 changes: 1 addition & 6 deletions clients/shared/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,10 @@ func Merge(ctx context.Context, dwh destination.DataWarehouse, tableData *optimi
return fmt.Errorf("primary keys cannot be empty")
}

validColumns := cols.ValidColumns()
validColumns := tableData.GetValidColumns()
if len(validColumns) == 0 {
return fmt.Errorf("columns cannot be empty")
}
for _, column := range validColumns {
if !column.IsValid() {
return fmt.Errorf("column %q is invalid and should be skipped", column.Name())
}
}

mergeStatements, err := dwh.Dialect().BuildMergeQueries(
tableID,
Expand Down
46 changes: 37 additions & 9 deletions clients/snowflake/dialect/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,20 @@ func TestSnowflakeDialect_BuildIsNotToastValueExpression(t *testing.T) {
)
}

func buildColumns(colTypesMap map[string]typing.KindDetails) *columns.Columns {
colNames := []string{}
func buildColumns(colTypesMap map[string]typing.KindDetails) []columns.Column {
var colNames []string
for colName := range colTypesMap {
colNames = append(colNames, colName)
}
// Sort the column names alphabetically to ensure deterministic order
slices.Sort(colNames)

var cols columns.Columns
var cols []columns.Column
for _, colName := range colNames {
cols.AddColumn(columns.NewColumn(colName, colTypesMap[colName]))
cols = append(cols, columns.NewColumn(colName, colTypesMap[colName]))
}

return &cols
return cols
}

func TestSnowflakeDialect_BuildMergeQueries_SoftDelete(t *testing.T) {
Expand All @@ -106,13 +106,20 @@ func TestSnowflakeDialect_BuildMergeQueries_SoftDelete(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns(fqTable)

var validCols []columns.Column
for _, col := range _cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

{
statements, err := SnowflakeDialect{}.BuildMergeQueries(
fakeTableID,
fqTable,
[]columns.Column{columns.NewColumn("id", typing.Invalid)},
nil,
_cols.ValidColumns(),
validCols,
true,
false,
)
Expand Down Expand Up @@ -140,12 +147,19 @@ func TestSnowflakeDialect_BuildMergeQueries(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns(fqTable)

var validCols []columns.Column
for _, col := range _cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

statements, err := SnowflakeDialect{}.BuildMergeQueries(
fakeTableID,
fqTable,
[]columns.Column{columns.NewColumn("id", typing.Invalid)},
nil,
_cols.ValidColumns(),
validCols,
false,
false,
)
Expand All @@ -170,6 +184,13 @@ func TestSnowflakeDialect_BuildMergeQueries_CompositeKey(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns(fqTable)

var validCols []columns.Column
for _, col := range _cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

statements, err := SnowflakeDialect{}.BuildMergeQueries(
fakeTableID,
fqTable,
Expand All @@ -178,7 +199,7 @@ func TestSnowflakeDialect_BuildMergeQueries_CompositeKey(t *testing.T) {
columns.NewColumn("another_id", typing.Invalid),
},
nil,
_cols.ValidColumns(),
validCols,
false,
false,
)
Expand All @@ -205,6 +226,13 @@ func TestSnowflakeDialect_BuildMergeQueries_EscapePrimaryKeys(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns(fqTable)

var validCols []columns.Column
for _, col := range _cols {
if col.IsValid() {
validCols = append(validCols, col)
}
}

statements, err := SnowflakeDialect{}.BuildMergeQueries(
fakeTableID,
fqTable,
Expand All @@ -213,7 +241,7 @@ func TestSnowflakeDialect_BuildMergeQueries_EscapePrimaryKeys(t *testing.T) {
columns.NewColumn("group", typing.Invalid),
},
nil,
_cols.ValidColumns(),
validCols,
false,
false,
)
Expand Down
Loading

0 comments on commit 87e6d9a

Please sign in to comment.