diff --git a/.github/tools/matrixchecker/main.go b/.github/tools/matrixchecker/main.go index 274cd0ee25..2a3b6a9de4 100644 --- a/.github/tools/matrixchecker/main.go +++ b/.github/tools/matrixchecker/main.go @@ -19,6 +19,7 @@ var IgnorePackages = []string{ "warehouse/integrations/testhelper", "warehouse/integrations/testdata", "warehouse/integrations/config", + "warehouse/integrations/types", } func main() { diff --git a/Makefile b/Makefile index ebc9be6603..8e962b0f25 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ endif test-warehouse-integration: $(eval TEST_PATTERN = 'TestIntegration') $(eval TEST_CMD = SLOW=1 go test) - $(eval TEST_OPTIONS = -v -p 8 -timeout 30m -count 1 -race -run $(TEST_PATTERN) -coverprofile=profile.out -covermode=atomic -coverpkg=./...) + $(eval TEST_OPTIONS = -v -p 8 -timeout 30m -count 1 -run $(TEST_PATTERN) -coverprofile=profile.out -covermode=atomic -coverpkg=./...) $(TEST_CMD) $(TEST_OPTIONS) $(package) && touch $(TESTFILE) || true test-warehouse: test-warehouse-integration test-teardown diff --git a/warehouse/client/client.go b/warehouse/client/client.go index a1af287539..32b35954bd 100644 --- a/warehouse/client/client.go +++ b/warehouse/client/client.go @@ -82,10 +82,10 @@ func (cl *Client) bqQuery(statement string) (result warehouseutils.QueryResult, for { var row []bigquery.Value err = it.Next(&row) - if err == iterator.Done { - break - } if err != nil { + if errors.Is(err, iterator.Done) { + break + } return } var stringRow []string diff --git a/warehouse/identity/identity.go b/warehouse/identity/identity.go index dfc8b3ec0a..b425fe1dec 100644 --- a/warehouse/identity/identity.go +++ b/warehouse/identity/identity.go @@ -4,6 +4,7 @@ import ( "compress/gzip" "context" "database/sql" + "errors" "fmt" "io" "os" @@ -243,7 +244,7 @@ func (idr *Identity) addRules(txn *sqlmiddleware.Tx, loadFileNames []string, gzW var record []string record, err = eventReader.Read(columnNames) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } pkgLogger.Errorf("IDR: Error while reading merge rule file %s for loading in staging table locally:%s: %v", loadFileName, mergeRulesStagingTable, err) diff --git a/warehouse/integrations/azure-synapse/azure-synapse.go b/warehouse/integrations/azure-synapse/azure-synapse.go index 339450e03e..f096a76908 100644 --- a/warehouse/integrations/azure-synapse/azure-synapse.go +++ b/warehouse/integrations/azure-synapse/azure-synapse.go @@ -17,6 +17,10 @@ import ( "unicode/utf16" "unicode/utf8" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" + + "github.com/samber/lo" + "github.com/rudderlabs/rudder-go-kit/stats" sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/logfield" @@ -46,14 +50,14 @@ const ( ) const ( - mssqlStringLengthLimit = 512 - provider = warehouseutils.AzureSynapse - tableNameLimit = 127 + stringLengthLimit = 512 + provider = warehouseutils.AzureSynapse + tableNameLimit = 127 ) var errorsMappings []model.JobError -var rudderDataTypesMapToMssql = map[string]string{ +var rudderDataTypesMapToAzureSynapse = map[string]string{ "int": "bigint", "float": "decimal(28,10)", "string": "varchar(512)", @@ -62,7 +66,7 @@ var rudderDataTypesMapToMssql = map[string]string{ "json": "jsonb", } -var mssqlDataTypesMapToRudder = map[string]string{ +var azureSynapseDataTypesMapToRudder = map[string]string{ "integer": "int", "smallint": "int", "bigint": "int", @@ -176,7 +180,14 @@ func (as *AzureSynapse) connect() 
(*sqlmw.DB, error) { db, sqlmw.WithStats(as.stats), sqlmw.WithLogger(as.logger), - sqlmw.WithKeyAndValues(as.defaultLogFields()), + sqlmw.WithKeyAndValues([]any{ + logfield.SourceID, as.Warehouse.Source.ID, + logfield.SourceType, as.Warehouse.Source.SourceDefinition.Name, + logfield.DestinationID, as.Warehouse.Destination.ID, + logfield.DestinationType, as.Warehouse.Destination.DestinationDefinition.Name, + logfield.WorkspaceID, as.Warehouse.WorkspaceID, + logfield.Namespace, as.Namespace, + }), sqlmw.WithQueryTimeout(as.connectTimeout), sqlmw.WithSlowQueryThreshold(as.config.slowQueryThreshold), ) @@ -195,255 +206,379 @@ func (as *AzureSynapse) connectionCredentials() *credentials { } } -func (as *AzureSynapse) defaultLogFields() []any { - return []any{ +func columnsWithDataTypes(columns model.TableSchema, prefix string) string { + formattedColumns := lo.MapToSlice(columns, func(name, dataType string) string { + return fmt.Sprintf(`"%s%s" %s`, prefix, name, rudderDataTypesMapToAzureSynapse[dataType]) + }) + return strings.Join(formattedColumns, ",") +} + +func (*AzureSynapse) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { + return +} + +func (as *AzureSynapse) loadTable( + ctx context.Context, + tableName string, + tableSchemaInUpload model.TableSchema, + skipTempTableDelete bool, +) (*types.LoadTableStats, string, error) { + log := as.logger.With( logfield.SourceID, as.Warehouse.Source.ID, logfield.SourceType, as.Warehouse.Source.SourceDefinition.Name, logfield.DestinationID, as.Warehouse.Destination.ID, logfield.DestinationType, as.Warehouse.Destination.DestinationDefinition.Name, logfield.WorkspaceID, as.Warehouse.WorkspaceID, logfield.Namespace, as.Namespace, - } -} + logfield.TableName, tableName, + ) + log.Infow("started loading") -func columnsWithDataTypes(columns model.TableSchema, prefix string) string { - var arr []string - for name, dataType := range columns { - arr = append(arr, fmt.Sprintf(`"%s%s" %s`, prefix, name, rudderDataTypesMapToMssql[dataType])) + fileNames, err := as.LoadFileDownLoader.Download(ctx, tableName) + if err != nil { + return nil, "", fmt.Errorf("downloading load files: %w", err) } - return strings.Join(arr, ",") -} + defer func() { + misc.RemoveFilePaths(fileNames...) + }() -func (*AzureSynapse) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { - return -} + stagingTableName := warehouseutils.StagingTableName( + provider, + tableName, + tableNameLimit, + ) -func (as *AzureSynapse) loadTable(ctx context.Context, tableName string, tableSchemaInUpload model.TableSchema, skipTempTableDelete bool) (stagingTableName string, err error) { - as.logger.Infof("AZ: Starting load for table:%s", tableName) + // The use of prepared statements for creating temporary tables is not suitable in this context. + // Temporary tables in SQL Server have a limited scope and are automatically purged after the transaction commits. + // Therefore, creating normal tables is chosen as an alternative. + // + // For more information on this behavior: + // - See the discussion at https://github.com/denisenkom/go-mssqldb/issues/149 regarding prepared statements. + // - Refer to Microsoft's documentation on temporary tables at + // https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175528(v=sql.105)?redirectedfrom=MSDN. 
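+	// The statement built below relies on the T-SQL idiom
+	//
+	//   SELECT TOP 0 * INTO <schema>.<staging_table> FROM <schema>.<target_table>;
+	//
+	// which creates the staging table with the same column definitions as the target
+	// table while copying zero rows, so the staging table mirrors the destination schema.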
+ log.Debugw("creating staging table") + createStagingTableStmt := fmt.Sprintf(` + SELECT + TOP 0 * INTO %[1]s.%[2]s + FROM + %[1]s.%[3]s;`, + as.Namespace, + stagingTableName, + tableName, + ) + if _, err = as.DB.ExecContext(ctx, createStagingTableStmt); err != nil { + return nil, "", fmt.Errorf("creating temporary table: %w", err) + } - previousColumnKeys := warehouseutils.SortColumnKeysFromColumnMap(as.Uploader.GetTableSchemaInWarehouse(tableName)) - // sort column names - sortedColumnKeys := warehouseutils.SortColumnKeysFromColumnMap(tableSchemaInUpload) + if !skipTempTableDelete { + defer func() { + as.dropStagingTable(ctx, stagingTableName) + }() + } - var extraColumns []string - for _, column := range previousColumnKeys { - if !slices.Contains(sortedColumnKeys, column) { - extraColumns = append(extraColumns, column) - } + txn, err := as.DB.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return nil, "", fmt.Errorf("begin transaction: %w", err) } - fileNames, err := as.LoadFileDownLoader.Download(ctx, tableName) - defer misc.RemoveFilePaths(fileNames...) + defer func() { + if err != nil { + _ = txn.Rollback() + } + }() + + sortedColumnKeys := warehouseutils.SortColumnKeysFromColumnMap( + tableSchemaInUpload, + ) + previousColumnKeys := warehouseutils.SortColumnKeysFromColumnMap( + as.Uploader.GetTableSchemaInWarehouse( + tableName, + ), + ) + extraColumns := lo.Filter(previousColumnKeys, func(item string, index int) bool { + return !slices.Contains(sortedColumnKeys, item) + }) + + log.Debugw("creating prepared stmt for loading data") + copyInStmt := mssql.CopyIn(as.Namespace+"."+stagingTableName, mssql.BulkOptions{CheckConstraints: false}, + append(sortedColumnKeys, extraColumns...)..., + ) + stmt, err := txn.PrepareContext(ctx, copyInStmt) if err != nil { - return + return nil, "", fmt.Errorf("preparing copyIn statement: %w", err) } - // create temporary table - stagingTableName = warehouseutils.StagingTableName(provider, tableName, tableNameLimit) - // prepared stmts cannot be used to create temp objects here. Will work in a txn, but will be purged after commit. 
- // https://github.com/denisenkom/go-mssqldb/issues/149, https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175528(v=sql.105)?redirectedfrom=MSDN - // sqlStatement := fmt.Sprintf(`CREATE TABLE ##%[2]s like %[1]s.%[3]s`, AZ.Namespace, stagingTableName, tableName) - // Hence falling back to creating normal tables - sqlStatement := fmt.Sprintf(`select top 0 * into %[1]s.%[2]s from %[1]s.%[3]s`, as.Namespace, stagingTableName, tableName) + log.Infow("loading data into staging table") + for _, fileName := range fileNames { + err = as.loadDataIntoStagingTable( + ctx, log, stmt, + fileName, sortedColumnKeys, + extraColumns, tableSchemaInUpload, + ) + if err != nil { + return nil, "", fmt.Errorf("loading data into staging table from file %s: %w", fileName, err) + } + } + if _, err = stmt.ExecContext(ctx); err != nil { + return nil, "", fmt.Errorf("executing copyIn statement: %w", err) + } - as.logger.Debugf("AZ: Creating temporary table for table:%s at %s\n", tableName, sqlStatement) - _, err = as.DB.ExecContext(ctx, sqlStatement) + log.Infow("deleting from load table") + rowsDeleted, err := as.deleteFromLoadTable( + ctx, txn, tableName, + stagingTableName, + ) if err != nil { - as.logger.Errorf("AZ: Error creating temporary table for table:%s: %v\n", tableName, err) - return + return nil, "", fmt.Errorf("delete from load table: %w", err) } - txn, err := as.DB.BeginTx(ctx, &sql.TxOptions{}) + log.Infow("inserting into load table") + rowsInserted, err := as.insertIntoLoadTable( + ctx, txn, tableName, + stagingTableName, sortedColumnKeys, + ) if err != nil { - as.logger.Errorf("AZ: Error while beginning a transaction in db for loading in table:%s: %v", tableName, err) - return + return nil, "", fmt.Errorf("insert into load table: %w", err) } - if !skipTempTableDelete { - defer as.dropStagingTable(ctx, stagingTableName) + log.Debugw("committing transaction") + if err = txn.Commit(); err != nil { + return nil, "", fmt.Errorf("commit transaction: %w", err) } - stmt, err := txn.PrepareContext(ctx, mssql.CopyIn(as.Namespace+"."+stagingTableName, mssql.BulkOptions{CheckConstraints: false}, append(sortedColumnKeys, extraColumns...)...)) + log.Infow("completed loading") + + return &types.LoadTableStats{ + RowsInserted: rowsInserted - rowsDeleted, + RowsUpdated: rowsDeleted, + }, stagingTableName, nil +} + +func (as *AzureSynapse) loadDataIntoStagingTable( + ctx context.Context, + log logger.Logger, + stmt *sql.Stmt, + fileName string, + sortedColumnKeys []string, + extraColumns []string, + tableSchemaInUpload model.TableSchema, +) error { + gzipFile, err := os.Open(fileName) if err != nil { - as.logger.Errorf("AZ: Error while preparing statement for transaction in db for loading in staging table:%s: %v\nstmt: %v", stagingTableName, err, stmt) - return + return fmt.Errorf("opening file: %w", err) } - for _, objectFileName := range fileNames { - var gzipFile *os.File - gzipFile, err = os.Open(objectFileName) - if err != nil { - as.logger.Errorf("AZ: Error opening file using os.Open for file:%s while loading to table %s", objectFileName, tableName) - return - } + defer func() { + _ = gzipFile.Close() + }() - var gzipReader *gzip.Reader - gzipReader, err = gzip.NewReader(gzipFile) - if err != nil { - as.logger.Errorf("AZ: Error reading file using gzip.NewReader for file:%s while loading to table %s", gzipFile, tableName) - gzipFile.Close() - return + gzipReader, err := gzip.NewReader(gzipFile) + if err != nil { + return fmt.Errorf("reading file: %w", err) + } + defer func() { + _ = 
gzipReader.Close() + }() - } - csvReader := csv.NewReader(gzipReader) - var csvRowsProcessedCount int - for { - var record []string - record, err = csvReader.Read() - if err != nil { - if err == io.EOF { - as.logger.Debugf("AZ: File reading completed while reading csv file for loading in staging table:%s: %s", stagingTableName, objectFileName) - break - } - as.logger.Errorf("AZ: Error while reading csv file %s for loading in staging table:%s: %v", objectFileName, stagingTableName, err) - txn.Rollback() - return - } - if len(sortedColumnKeys) != len(record) { - err = fmt.Errorf(`load file CSV columns for a row mismatch number found in upload schema. Columns in CSV row: %d, Columns in upload schema of table-%s: %d. Processed rows in csv file until mismatch: %d`, len(record), tableName, len(sortedColumnKeys), csvRowsProcessedCount) - as.logger.Error(err) - txn.Rollback() - return - } - var recordInterface []interface{} - for _, value := range record { - if strings.TrimSpace(value) == "" { - recordInterface = append(recordInterface, nil) - } else { - recordInterface = append(recordInterface, value) - } + csvReader := csv.NewReader(gzipReader) + + for { + record, err := csvReader.Read() + if err != nil { + if errors.Is(err, io.EOF) { + break } - var finalColumnValues []interface{} - for index, value := range recordInterface { - valueType := tableSchemaInUpload[sortedColumnKeys[index]] - if value == nil { - as.logger.Debugf("AZ : Found nil value for type : %s, column : %s", valueType, sortedColumnKeys[index]) - finalColumnValues = append(finalColumnValues, nil) - continue - } - strValue := value.(string) - switch valueType { - case "int": - var convertedValue int - if convertedValue, err = strconv.Atoi(strValue); err != nil { - as.logger.Errorf("AZ : Mismatch in datatype for type : %s, column : %s, value : %s, err : %v", valueType, sortedColumnKeys[index], strValue, err) - finalColumnValues = append(finalColumnValues, nil) - } else { - finalColumnValues = append(finalColumnValues, convertedValue) - } - case "float": - var convertedValue float64 - if convertedValue, err = strconv.ParseFloat(strValue, 64); err != nil { - as.logger.Errorf("MS : Mismatch in datatype for type : %s, column : %s, value : %s, err : %v", valueType, sortedColumnKeys[index], strValue, err) - finalColumnValues = append(finalColumnValues, nil) - } else { - finalColumnValues = append(finalColumnValues, convertedValue) - } - case "datetime": - var convertedValue time.Time - // TODO : handling milli? - if convertedValue, err = time.Parse(time.RFC3339, strValue); err != nil { - as.logger.Errorf("AZ : Mismatch in datatype for type : %s, column : %s, value : %s, err : %v", valueType, sortedColumnKeys[index], strValue, err) - finalColumnValues = append(finalColumnValues, nil) - } else { - finalColumnValues = append(finalColumnValues, convertedValue) - } - // TODO : handling all cases? 
- case "boolean": - var convertedValue bool - if convertedValue, err = strconv.ParseBool(strValue); err != nil { - as.logger.Errorf("AZ : Mismatch in datatype for type : %s, column : %s, value : %s, err : %v", valueType, sortedColumnKeys[index], strValue, err) - finalColumnValues = append(finalColumnValues, nil) - } else { - finalColumnValues = append(finalColumnValues, convertedValue) - } - case "string": - // This is needed to enable diacritic support Ex: Ü,ç Ç,©,∆,ß,á,ù,ñ,ê - // A substitute to this PR; https://github.com/denisenkom/go-mssqldb/pull/576/files - // An alternate to this approach is to use nvarchar(instead of varchar) - if len(strValue) > mssqlStringLengthLimit { - strValue = strValue[:mssqlStringLengthLimit] - } - var byteArr []byte - if hasDiacritics(strValue) { - as.logger.Debug("diacritics " + strValue) - byteArr = str2ucs2(strValue) - // This is needed as with above operation every character occupies 2 bytes - if len(byteArr) > mssqlStringLengthLimit { - byteArr = byteArr[:mssqlStringLengthLimit] - } - finalColumnValues = append(finalColumnValues, byteArr) - } else { - as.logger.Debug("non-diacritic : " + strValue) - finalColumnValues = append(finalColumnValues, strValue) - } - default: - finalColumnValues = append(finalColumnValues, value) - } + return fmt.Errorf("reading record from: %w", err) + } + if len(sortedColumnKeys) != len(record) { + return fmt.Errorf("mismatch in number of columns: actual count: %d, expected count: %d", + len(record), + len(sortedColumnKeys), + ) + } + + recordInterface := make([]interface{}, 0, len(record)) + for _, value := range record { + if strings.TrimSpace(value) == "" { + recordInterface = append(recordInterface, nil) + } else { + recordInterface = append(recordInterface, value) } - // This is needed for the copyIn to proceed successfully for azure synapse else will face below err for missing old columns - // mssql: Column count in target table does not match column count specified in input. - // If BCP command, ensure format file column count matches destination table. If SSIS data import, check column mappings are consistent with target. - for range extraColumns { + } + + finalColumnValues := make([]interface{}, 0, len(record)) + for index, value := range recordInterface { + valueType := tableSchemaInUpload[sortedColumnKeys[index]] + if value == nil { + log.Debugw("found nil value", + logfield.ColumnType, valueType, + logfield.ColumnName, sortedColumnKeys[index], + ) + finalColumnValues = append(finalColumnValues, nil) + continue } - _, err = stmt.ExecContext(ctx, finalColumnValues...) 
+ + processedVal, err := as.ProcessColumnValue( + value.(string), + valueType, + ) if err != nil { - as.logger.Errorf("AZ: Error in exec statement for loading in staging table:%s: %v", stagingTableName, err) - txn.Rollback() - return + log.Warnw("mismatch in datatype", + logfield.ColumnType, valueType, + logfield.ColumnName, sortedColumnKeys[index], + logfield.ColumnValue, value, + logfield.Error, err, + ) + finalColumnValues = append(finalColumnValues, nil) + } else { + finalColumnValues = append(finalColumnValues, processedVal) } - csvRowsProcessedCount++ } - gzipReader.Close() - gzipFile.Close() - } - _, err = stmt.ExecContext(ctx) - if err != nil { - txn.Rollback() - as.logger.Errorf("AZ: Rollback transaction as there was error while loading staging table:%s: %v", stagingTableName, err) - return + // To ensure the successful execution of the 'copyIn' operation in Azure Synapse, + // it is necessary to handle the scenario where new columns are added to the target table. + // Without this adjustment, attempting to perform 'copyIn' when the column count in the + // target table does not match the column count specified in the input data will result + // in an error like: + // + // mssql: Column count in target table does not match column count specified in input. + // + // If this error is encountered, it is important to verify that the column structure in + // the source data matches the destination table's structure. If you are using the BCP command, + // ensure that the format file's column count matches the destination table. For SSIS data imports, + // double-check that the column mappings are consistent with the target table. + for range extraColumns { + finalColumnValues = append(finalColumnValues, nil) + } + _, err = stmt.ExecContext(ctx, finalColumnValues...) 
+ if err != nil { + return fmt.Errorf("exec statement record: %w", err) + } + } + return nil +} + +func (as *AzureSynapse) ProcessColumnValue( + value string, + valueType string, +) (interface{}, error) { + switch valueType { + case model.IntDataType: + return strconv.Atoi(value) + case model.FloatDataType: + return strconv.ParseFloat(value, 64) + case model.DateTimeDataType: + return time.Parse(time.RFC3339, value) + case model.BooleanDataType: + return strconv.ParseBool(value) + case model.StringDataType: + if len(value) > stringLengthLimit { + value = value[:stringLengthLimit] + } + if !hasDiacritics(value) { + return value, nil + } else { + byteArr := str2ucs2(value) + if len(byteArr) > stringLengthLimit { + byteArr = byteArr[:stringLengthLimit] + } + return byteArr, nil + } + default: + return value, nil } - // deduplication process +} + +func (as *AzureSynapse) deleteFromLoadTable( + ctx context.Context, + txn *sqlmw.Tx, + tableName string, + stagingTableName string, +) (int64, error) { primaryKey := "id" if column, ok := primaryKeyMap[tableName]; ok { primaryKey = column } - partitionKey := "id" - if column, ok := partitionKeyMap[tableName]; ok { - partitionKey = column - } + var additionalJoinClause string if tableName == warehouseutils.DiscardsTable { - additionalJoinClause = fmt.Sprintf(`AND _source.%[3]s = "%[1]s"."%[2]s"."%[3]s" AND _source.%[4]s = "%[1]s"."%[2]s"."%[4]s"`, as.Namespace, tableName, "table_name", "column_name") - } - sqlStatement = fmt.Sprintf(`DELETE FROM "%[1]s"."%[2]s" FROM "%[1]s"."%[3]s" as _source where (_source.%[4]s = "%[1]s"."%[2]s"."%[4]s" %[5]s)`, as.Namespace, tableName, stagingTableName, primaryKey, additionalJoinClause) - as.logger.Infof("AZ: Deduplicate records for table:%s using staging table: %s\n", tableName, sqlStatement) - _, err = txn.ExecContext(ctx, sqlStatement) - if err != nil { - as.logger.Errorf("AZ: Error deleting from original table for dedup: %v\n", err) - txn.Rollback() - return + additionalJoinClause = fmt.Sprintf(`AND _source.%[3]s = %[1]q.%[2]q.%[3]q AND _source.%[4]s = %[1]q.%[2]q.%[4]q`, + as.Namespace, + tableName, + "table_name", + "column_name", + ) } - quotedColumnNames := warehouseutils.DoubleQuoteAndJoinByComma(sortedColumnKeys) - sqlStatement = fmt.Sprintf(`INSERT INTO "%[1]s"."%[2]s" (%[3]s) SELECT %[3]s FROM ( SELECT *, row_number() OVER (PARTITION BY %[5]s ORDER BY received_at DESC) AS _rudder_staging_row_number FROM "%[1]s"."%[4]s" ) AS _ where _rudder_staging_row_number = 1`, as.Namespace, tableName, quotedColumnNames, stagingTableName, partitionKey) - as.logger.Infof("AZ: Inserting records for table:%s using staging table: %s\n", tableName, sqlStatement) - _, err = txn.ExecContext(ctx, sqlStatement) + deleteStmt := fmt.Sprintf(` + DELETE FROM + %[1]q.%[2]q + FROM + %[1]q.%[3]q AS _source + WHERE + ( + _source.%[4]s = %[1]q.%[2]q.%[4]q %[5]s + );`, + as.Namespace, + tableName, + stagingTableName, + primaryKey, + additionalJoinClause, + ) + r, err := txn.ExecContext(ctx, deleteStmt) if err != nil { - as.logger.Errorf("AZ: Error inserting into original table: %v\n", err) - txn.Rollback() - return + return 0, fmt.Errorf("deleting from main table: %w", err) } + return r.RowsAffected() +} - if err = txn.Commit(); err != nil { - as.logger.Errorf("AZ: Error while committing transaction as there was error while loading staging table:%s: %v", stagingTableName, err) - return +func (as *AzureSynapse) insertIntoLoadTable( + ctx context.Context, + txn *sqlmw.Tx, + tableName string, + stagingTableName string, + sortedColumnKeys 
[]string,
+) (int64, error) {
+	partitionKey := "id"
+	if column, ok := partitionKeyMap[tableName]; ok {
+		partitionKey = column
 	}
-	as.logger.Infof("AZ: Complete load for table:%s", tableName)
-	return
+
+	quotedColumnNames := warehouseutils.DoubleQuoteAndJoinByComma(
+		sortedColumnKeys,
+	)
+
+	insertStmt := fmt.Sprintf(`
+		INSERT INTO %[1]q.%[2]q (%[3]s)
+		SELECT
+		  %[3]s
+		FROM
+		  (
+			SELECT
+			  *,
+			  ROW_NUMBER() OVER (
+				PARTITION BY %[5]s
+				ORDER BY
+				  received_at DESC
+			  ) AS _rudder_staging_row_number
+			FROM
+			  %[1]q.%[4]q
+		  ) AS _
+		WHERE
+		  _rudder_staging_row_number = 1;`,
+		as.Namespace,
+		tableName,
+		quotedColumnNames,
+		stagingTableName,
+		partitionKey,
+	)
+
+	r, err := txn.ExecContext(ctx, insertStmt)
+	if err != nil {
+		return 0, fmt.Errorf("inserting into main table: %w", err)
+	}
+	return r.RowsAffected()
 }
 // Taken from https://github.com/denisenkom/go-mssqldb/blob/master/tds.go
@@ -469,7 +604,7 @@ func hasDiacritics(str string) bool {
 func (as *AzureSynapse) loadUserTables(ctx context.Context) (errorMap map[string]error) {
 	errorMap = map[string]error{warehouseutils.IdentifiesTable: nil}
 	as.logger.Infof("AZ: Starting load for identifies and users tables\n")
-	identifyStagingTable, err := as.loadTable(ctx, warehouseutils.IdentifiesTable, as.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), true)
+	_, identifyStagingTable, err := as.loadTable(ctx, warehouseutils.IdentifiesTable, as.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), true)
 	if err != nil {
 		errorMap[warehouseutils.IdentifiesTable] = err
 		return
@@ -556,8 +691,8 @@ func (as *AzureSynapse) loadUserTables(ctx context.Context) (errorMap map[string
 	as.logger.Infof("AZ: Dedup records for table:%s using staging table: %s\n", warehouseutils.UsersTable, sqlStatement)
 	_, err = tx.ExecContext(ctx, sqlStatement)
 	if err != nil {
-		as.logger.Errorf("AZ: Error deleting from original table for dedup: %v\n", err)
-		tx.Rollback()
+		as.logger.Errorf("AZ: Error deleting from main table for dedup: %v\n", err)
+		_ = tx.Rollback()
 		errorMap[warehouseutils.UsersTable] = err
 		return
 	}
@@ -568,7 +703,7 @@ func (as *AzureSynapse) loadUserTables(ctx context.Context) (errorMap map[string

 	if err != nil {
 		as.logger.Errorf("AZ: Error inserting into users table from staging table: %v\n", err)
-		tx.Rollback()
+		_ = tx.Rollback()
 		errorMap[warehouseutils.UsersTable] = err
 		return
 	}
@@ -576,7 +711,7 @@ func (as *AzureSynapse) loadUserTables(ctx context.Context) (errorMap map[string
 	err = tx.Commit()
 	if err != nil {
 		as.logger.Errorf("AZ: Error in transaction commit for users table: %v\n", err)
-		tx.Rollback()
+		_ = tx.Rollback()
 		errorMap[warehouseutils.UsersTable] = err
 		return
 	}
@@ -589,11 +724,10 @@ func (*AzureSynapse) DeleteBy(context.Context, []string, warehouseutils.DeleteBy
 func (as *AzureSynapse) CreateSchema(ctx context.Context) (err error) {
 	sqlStatement := fmt.Sprintf(`IF NOT EXISTS ( SELECT * FROM sys.schemas WHERE name = N'%s' )
-    EXEC('CREATE SCHEMA [%s]');
-`, as.Namespace, as.Namespace)
+    EXEC('CREATE SCHEMA [%s]');`, as.Namespace, as.Namespace)
 	as.logger.Infof("SYNAPSE: Creating schema name in synapse for AZ:%s : %v", as.Warehouse.Destination.ID, sqlStatement)
 	_, err = as.DB.ExecContext(ctx, sqlStatement)
-	if err == io.EOF {
+	if errors.Is(err, io.EOF) {
 		return nil
 	}
 	return
@@ -645,8 +779,7 @@ func (as *AzureSynapse) AddColumns(ctx context.Context, tableName string, column
 			WHERE
 				OBJECT_ID = OBJECT_ID(N'%[1]s.%[2]s')
 				AND name = '%[3]s'
-		)
-`,
+		)`,
 		as.Namespace,
 		tableName,
 		columnsInfo[0].Name,
@@ -662,7 +795,7 @@ 
func (as *AzureSynapse) AddColumns(ctx context.Context, tableName string, column )) for _, columnInfo := range columnsInfo { - queryBuilder.WriteString(fmt.Sprintf(` %q %s,`, columnInfo.Name, rudderDataTypesMapToMssql[columnInfo.Type])) + queryBuilder.WriteString(fmt.Sprintf(` %q %s,`, columnInfo.Name, rudderDataTypesMapToAzureSynapse[columnInfo.Type])) } query = strings.TrimSuffix(queryBuilder.String(), ",") @@ -790,7 +923,7 @@ func (as *AzureSynapse) FetchSchema(ctx context.Context) (model.Schema, model.Sc if _, ok := schema[tableName]; !ok { schema[tableName] = make(model.TableSchema) } - if datatype, ok := mssqlDataTypesMapToRudder[columnType]; ok { + if datatype, ok := azureSynapseDataTypesMapToRudder[columnType]; ok { schema[tableName][columnName] = datatype } else { if _, ok := unrecognizedSchema[tableName]; !ok { @@ -812,16 +945,21 @@ func (as *AzureSynapse) LoadUserTables(ctx context.Context) map[string]error { return as.loadUserTables(ctx) } -func (as *AzureSynapse) LoadTable(ctx context.Context, tableName string) error { - _, err := as.loadTable(ctx, tableName, as.Uploader.GetTableSchemaInUpload(tableName), false) - return err +func (as *AzureSynapse) LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) { + loadTableStat, _, err := as.loadTable( + ctx, + tableName, + as.Uploader.GetTableSchemaInUpload(tableName), + false, + ) + return loadTableStat, err } func (as *AzureSynapse) Cleanup(ctx context.Context) { if as.DB != nil { // extra check aside dropStagingTable(table) as.dropDanglingStagingTables(ctx) - as.DB.Close() + _ = as.DB.Close() } } @@ -837,22 +975,6 @@ func (*AzureSynapse) DownloadIdentityRules(context.Context, *misc.GZipWriter) (e return } -func (as *AzureSynapse) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) { - var ( - total int64 - err error - sqlStatement string - ) - sqlStatement = fmt.Sprintf(` - SELECT count(*) FROM "%[1]s"."%[2]s"; - `, - as.Namespace, - tableName, - ) - err = as.DB.QueryRowContext(ctx, sqlStatement).Scan(&total) - return total, err -} - func (as *AzureSynapse) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) { as.Warehouse = warehouse as.Namespace = warehouse.Namespace diff --git a/warehouse/integrations/azure-synapse/azure_synapse_test.go b/warehouse/integrations/azure-synapse/azure_synapse_test.go index 00faca43c5..0d4ef9a164 100644 --- a/warehouse/integrations/azure-synapse/azure_synapse_test.go +++ b/warehouse/integrations/azure-synapse/azure_synapse_test.go @@ -6,9 +6,20 @@ import ( "fmt" "os" "strconv" + "strings" "testing" "time" + "github.com/golang/mock/gomock" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + azuresynapse "github.com/rudderlabs/rudder-server/warehouse/integrations/azure-synapse" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" @@ -32,7 +43,11 @@ func TestIntegration(t *testing.T) { t.Skip("Skipping tests. 
Add 'SLOW=1' env var to run test.") } - c := testcompose.New(t, compose.FilePaths([]string{"testdata/docker-compose.yml", "../testdata/docker-compose.jobsdb.yml", "../testdata/docker-compose.minio.yml"})) + c := testcompose.New(t, compose.FilePaths([]string{ + "testdata/docker-compose.yml", + "../testdata/docker-compose.jobsdb.yml", + "../testdata/docker-compose.minio.yml", + })) c.Start(context.Background()) misc.Init() @@ -63,6 +78,7 @@ func TestIntegration(t *testing.T) { bucketName := "testbucket" accessKeyID := "MYACCESSKEY" secretAccessKey := "MYSECRETKEY" + region := "us-east-1" minioEndpoint := fmt.Sprintf("localhost:%d", minioPort) @@ -91,10 +107,8 @@ func TestIntegration(t *testing.T) { t.Setenv("MINIO_SECRET_ACCESS_KEY", secretAccessKey) t.Setenv("MINIO_MINIO_ENDPOINT", minioEndpoint) t.Setenv("MINIO_SSL", "false") - t.Setenv("RSERVER_WAREHOUSE_AZURE_SYNAPSE_MAX_PARALLEL_LOADS", "8") t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) - t.Setenv("RSERVER_WAREHOUSE_AZURE_SYNAPSE_SLOW_QUERY_THRESHOLD", "0s") svcDone := make(chan struct{}) @@ -113,6 +127,9 @@ func TestIntegration(t *testing.T) { health.WaitUntilReady(ctx, t, serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint") t.Run("Events flow", func(t *testing.T) { + t.Setenv("RSERVER_WAREHOUSE_AZURE_SYNAPSE_MAX_PARALLEL_LOADS", "8") + t.Setenv("RSERVER_WAREHOUSE_AZURE_SYNAPSE_SLOW_QUERY_THRESHOLD", "0s") + jobsDB := testhelper.JobsDB(t, jobsDBPort) dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%d?TrustServerCertificate=true&database=%s&encrypt=disable", @@ -155,7 +172,7 @@ func TestIntegration(t *testing.T) { Type: client.SQLClient, } - conf := map[string]interface{}{ + conf := map[string]any{ "bucketProvider": "MINIO", "bucketName": bucketName, "accessKeyID": accessKeyID, @@ -211,7 +228,7 @@ func TestIntegration(t *testing.T) { t.Run("Validations", func(t *testing.T) { dest := backendconfig.DestinationT{ ID: destinationID, - Config: map[string]interface{}{ + Config: map[string]any{ "host": host, "database": database, "user": user, @@ -239,4 +256,455 @@ func TestIntegration(t *testing.T) { } testhelper.VerifyConfigurationTest(t, dest) }) + + t.Run("Load Table", func(t *testing.T) { + const ( + sourceID = "test_source_id" + destinationID = "test_destination_id" + workspaceID = "test_workspace_id" + ) + + namespace := testhelper.RandSchema(destType) + + schemaInUpload := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + } + schemaInWarehouse := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + "extra_test_bool": "boolean", + "extra_test_datetime": "datetime", + "extra_test_float": "float", + "extra_test_int": "int", + "extra_test_string": "string", + } + + warehouse := model.Warehouse{ + Source: backendconfig.SourceT{ + ID: sourceID, + }, + Destination: backendconfig.DestinationT{ + ID: destinationID, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: destType, + }, + Config: map[string]any{ + "host": host, + "database": database, + "user": user, + "password": password, + "port": strconv.Itoa(azureSynapsePort), + "sslMode": "disable", + "namespace": "", + "bucketProvider": "MINIO", + "bucketName": bucketName, + 
"accessKeyID": accessKeyID, + "secretAccessKey": secretAccessKey, + "useSSL": false, + "endPoint": minioEndpoint, + "syncFrequency": "30", + "useRudderStorage": false, + }, + }, + WorkspaceID: workspaceID, + Namespace: namespace, + } + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: warehouseutils.MINIO, + Config: map[string]any{ + "bucketName": bucketName, + "accessKeyID": accessKeyID, + "secretAccessKey": secretAccessKey, + "endPoint": minioEndpoint, + "forcePathStyle": true, + "s3ForcePathStyle": true, + "disableSSL": true, + "region": region, + "enableSSE": false, + "bucketProvider": warehouseutils.MINIO, + }, + }) + require.NoError(t, err) + + t.Run("schema does not exists", func(t *testing.T) { + tableName := "schema_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + err := az.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + loadTableStat, err := az.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("table does not exists", func(t *testing.T) { + tableName := "table_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + err := az.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = az.CreateSchema(ctx) + require.NoError(t, err) + + loadTableStat, err := az.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("merge", func(t *testing.T) { + tableName := "merge_test_table" + + t.Run("without dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + err := az.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = az.CreateSchema(ctx) + require.NoError(t, err) + + err = az.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := az.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = az.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, az.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + cast(test_float AS float) AS test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("with dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := 
[]warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + err := az.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = az.CreateSchema(ctx) + require.NoError(t, err) + + err = az.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := az.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, az.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + cast(test_float AS float) AS test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DedupTestRecords()) + }) + }) + t.Run("load file does not exists", func(t *testing.T) { + tableName := "load_file_not_exists_test_table" + + loadFiles := []warehouseutils.LoadFile{{ + Location: "http://localhost:1234/testbucket/rudder-warehouse-load-objects/load_file_not_exists_test_table/test_source_id/f31af97e-03e8-46d0-8a1a-1786cb85b22c-load_file_not_exists_test_table/load.csv.gz", + }} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + err := az.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = az.CreateSchema(ctx) + require.NoError(t, err) + + err = az.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := az.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in number of columns", func(t *testing.T) { + tableName := "mismatch_columns_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + err := az.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = az.CreateSchema(ctx) + require.NoError(t, err) + + err = az.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := az.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in schema", func(t *testing.T) { + tableName := "mismatch_schema_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + err := az.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = az.CreateSchema(ctx) + require.NoError(t, err) + + err = az.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := az.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, az.DB.DB, 
+ fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + cast(test_float AS float) AS test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.MismatchSchemaTestRecords()) + }) + t.Run("discards", func(t *testing.T) { + tableName := warehouseutils.DiscardsTable + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) + + az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + err := az.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = az.CreateSchema(ctx) + require.NoError(t, err) + + err = az.CreateTable(ctx, tableName, warehouseutils.DiscardsSchema) + require.NoError(t, err) + + loadTableStat, err := az.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(6)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, az.DB.DB, + fmt.Sprintf(` + SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM + %q.%q + ORDER BY row_id ASC; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DiscardTestRecords()) + }) + }) +} + +func TestAzureSynapse_ProcessColumnValue(t *testing.T) { + testCases := []struct { + name string + data string + dataType string + expectedValue interface{} + wantError bool + }{ + { + name: "invalid integer", + data: "1.01", + dataType: model.IntDataType, + wantError: true, + }, + { + name: "valid integer", + data: "1", + dataType: model.IntDataType, + expectedValue: int64(1), + }, + { + name: "invalid float", + data: "test", + dataType: model.FloatDataType, + wantError: true, + }, + { + name: "valid float", + data: "1.01", + dataType: model.FloatDataType, + expectedValue: float64(1.01), + }, + { + name: "invalid boolean", + data: "test", + dataType: model.BooleanDataType, + wantError: true, + }, + { + name: "valid boolean", + data: "true", + dataType: model.BooleanDataType, + expectedValue: true, + }, + { + name: "invalid datetime", + data: "1", + dataType: model.DateTimeDataType, + wantError: true, + }, + { + name: "valid datetime", + data: "2020-01-01T00:00:00Z", + dataType: model.DateTimeDataType, + expectedValue: time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "valid string", + data: "test", + dataType: model.StringDataType, + expectedValue: "test", + }, + { + name: "valid string exceeding max length", + data: strings.Repeat("test", 200), + dataType: model.StringDataType, + expectedValue: strings.Repeat("test", 128), + }, + { + name: "valid string with diacritics", + data: "tést", + dataType: model.StringDataType, + expectedValue: []byte{0x74, 0x0, 0xe9, 0x0, 0x73, 0x0, 0x74, 0x0}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + az := azuresynapse.New(config.Default, logger.NOP, stats.Default) + + value, err := az.ProcessColumnValue(tc.data, tc.dataType) + if tc.wantError { + require.Error(t, err) + return + } + require.EqualValues(t, tc.expectedValue, value) + require.NoError(t, err) + }) + } +} + +func newMockUploader( + t testing.TB, + loadFiles []warehouseutils.LoadFile, + tableName string, + schemaInUpload model.TableSchema, + schemaInWarehouse 
model.TableSchema, +) warehouseutils.Uploader { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockUploader := mockuploader.NewMockUploader(ctrl) + mockUploader.EXPECT().UseRudderStorage().Return(false).AnyTimes() + mockUploader.EXPECT().GetLoadFilesMetadata(gomock.Any(), gomock.Any()).Return(loadFiles).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() + + return mockUploader } diff --git a/warehouse/integrations/bigquery/bigquery.go b/warehouse/integrations/bigquery/bigquery.go index d8447a52d7..b483e1d75b 100644 --- a/warehouse/integrations/bigquery/bigquery.go +++ b/warehouse/integrations/bigquery/bigquery.go @@ -3,13 +3,19 @@ package bigquery import ( "context" "encoding/json" + "errors" "fmt" "regexp" "strings" "time" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" + + "github.com/samber/lo" + "cloud.google.com/go/bigquery" "golang.org/x/exp/slices" + bqService "google.golang.org/api/bigquery/v2" "google.golang.org/api/googleapi" "google.golang.org/api/iterator" "google.golang.org/api/option" @@ -45,7 +51,7 @@ type BigQuery struct { } } -type StagingLoadTable struct { +type loadTableResponse struct { partitionDate string stagingTableName string } @@ -147,7 +153,7 @@ func (bq *BigQuery) getMiddleware() *middleware.Client { if bq.middleware != nil { return bq.middleware } - middleware := middleware.New( + return middleware.New( bq.db, middleware.WithLogger(bq.logger), middleware.WithKeyAndValues( @@ -160,15 +166,12 @@ func (bq *BigQuery) getMiddleware() *middleware.Client { ), middleware.WithSlowQueryThreshold(bq.config.slowQueryThreshold), ) - return middleware } -func getTableSchema(columns model.TableSchema) []*bigquery.FieldSchema { - var schema []*bigquery.FieldSchema - for columnName, columnType := range columns { - schema = append(schema, &bigquery.FieldSchema{Name: columnName, Type: dataTypesMap[columnType]}) - } - return schema +func getTableSchema(tableSchema model.TableSchema) []*bigquery.FieldSchema { + return lo.MapToSlice(tableSchema, func(columnName, columnType string) *bigquery.FieldSchema { + return &bigquery.FieldSchema{Name: columnName, Type: dataTypesMap[columnType]} + }) } func (bq *BigQuery) DeleteTable(ctx context.Context, tableName string) (err error) { @@ -351,181 +354,301 @@ func partitionedTable(tableName, partitionDate string) string { return fmt.Sprintf(`%s$%v`, tableName, strings.ReplaceAll(partitionDate, "-", "")) } -func (bq *BigQuery) loadTable(ctx context.Context, tableName string, _, getLoadFileLocFromTableUploads, skipTempTableDelete bool) (stagingLoadTable StagingLoadTable, err error) { - bq.logger.Infof("BQ: Starting load for table:%s\n", tableName) - var loadFiles []warehouseutils.LoadFile - if getLoadFileLocFromTableUploads { - loadFile, err := bq.uploader.GetSingleLoadFile(ctx, tableName) - if err != nil { - return stagingLoadTable, err - } - loadFiles = append(loadFiles, loadFile) - } else { - loadFiles = bq.uploader.GetLoadFilesMetadata(ctx, warehouseutils.GetLoadFilesOptions{Table: tableName}) +func (bq *BigQuery) loadTable( + ctx context.Context, + tableName string, + skipTempTableDelete bool, +) (*types.LoadTableStats, *loadTableResponse, error) { + log := bq.logger.With( + logfield.SourceID, bq.warehouse.Source.ID, + logfield.SourceType, bq.warehouse.Source.SourceDefinition.Name, + logfield.DestinationID, bq.warehouse.Destination.ID, + 
logfield.DestinationType, bq.warehouse.Destination.DestinationDefinition.Name, + logfield.WorkspaceID, bq.warehouse.WorkspaceID, + logfield.Namespace, bq.namespace, + logfield.TableName, tableName, + logfield.LoadTableStrategy, bq.loadTableStrategy(), + ) + log.Infow("started loading") + + loadFileLocations, err := bq.loadFileLocations(ctx, tableName) + if err != nil { + return nil, nil, fmt.Errorf("getting load file locations: %w", err) } - gcsLocations := warehouseutils.GetGCSLocations(loadFiles, warehouseutils.GCSLocationOptions{}) - bq.logger.Infof("BQ: Loading data into table: %s in bigquery dataset: %s in project: %s", tableName, bq.namespace, bq.projectID) - gcsRef := bigquery.NewGCSReference(gcsLocations...) + + gcsRef := bigquery.NewGCSReference(warehouseutils.GetGCSLocations( + loadFileLocations, + warehouseutils.GCSLocationOptions{}, + )...) gcsRef.SourceFormat = bigquery.JSON gcsRef.MaxBadRecords = 0 gcsRef.IgnoreUnknownValues = false - loadTableByAppend := func() (err error) { - stagingLoadTable.partitionDate = time.Now().Format("2006-01-02") - outputTable := partitionedTable(tableName, stagingLoadTable.partitionDate) - // Tables created by RudderStack are ingestion-time partitioned table with pseudo column named _PARTITIONTIME. BigQuery automatically assigns rows to partitions based - // on the time when BigQuery ingests the data. To support custom field partitions, omitting loading into partitioned table like tableName$20191221 - // TODO: Support custom field partition on users & identifies tables - if bq.config.customPartitionsEnabled || slices.Contains(bq.config.customPartitionsEnabledWorkspaceIDs, bq.warehouse.WorkspaceID) { - outputTable = tableName - } + if bq.dedupEnabled() { + return bq.loadTableByMerge(ctx, tableName, gcsRef, log, skipTempTableDelete) + } + return bq.loadTableByAppend(ctx, tableName, gcsRef, log) +} - loader := bq.db.Dataset(bq.namespace).Table(outputTable).LoaderFrom(gcsRef) +func (bq *BigQuery) loadTableStrategy() string { + if bq.dedupEnabled() { + return "MERGE" + } + return "APPEND" +} - job, err := loader.Run(ctx) - if err != nil { - bq.logger.Errorf("BQ: Error initiating append load job: %v\n", err) - return - } - status, err := job.Wait(ctx) +func (bq *BigQuery) loadFileLocations( + ctx context.Context, + tableName string, +) ([]warehouseutils.LoadFile, error) { + switch tableName { + case warehouseutils.IdentityMappingsTable, warehouseutils.IdentityMergeRulesTable: + loadfile, err := bq.uploader.GetSingleLoadFile( + ctx, + tableName, + ) if err != nil { - bq.logger.Errorf("BQ: Error running append load job: %v\n", err) - return + return nil, fmt.Errorf("getting single load file for table %s: %w", tableName, err) } + return []warehouseutils.LoadFile{loadfile}, nil + default: + metadata := bq.uploader.GetLoadFilesMetadata( + ctx, + warehouseutils.GetLoadFilesOptions{Table: tableName}, + ) + return metadata, nil + } +} - if status.Err() != nil { - return status.Err() - } - return +// loadTableByAppend loads data into a table by appending to it +// +// In BigQuery, tables created by RudderStack are typically ingestion-time partitioned tables +// with a pseudo-column named _PARTITIONTIME. BigQuery automatically assigns rows to partitions +// based on the time when BigQuery ingests the data. To support custom field partitions, it is +// important to avoid loading data into partitioned tables with names like tableName$20191221. 
+// Instead, ensure that data is loaded into the appropriate ingestion-time partition, allowing +// BigQuery to manage partitioning based on the data's ingestion time. +// +// TODO: Support custom field partition on users & identifies tables +func (bq *BigQuery) loadTableByAppend( + ctx context.Context, + tableName string, + gcsRef *bigquery.GCSReference, + log logger.Logger, +) (*types.LoadTableStats, *loadTableResponse, error) { + partitionDate := time.Now().Format("2006-01-02") + + outputTable := partitionedTable( + tableName, + partitionDate, + ) + if bq.config.customPartitionsEnabled || slices.Contains(bq.config.customPartitionsEnabledWorkspaceIDs, bq.warehouse.WorkspaceID) { + outputTable = tableName } - loadTableByMerge := func() (err error) { - stagingTableName := warehouseutils.StagingTableName(provider, tableName, tableNameLimit) - stagingLoadTable.stagingTableName = stagingTableName - bq.logger.Infof("BQ: Loading data into temporary table: %s in bigquery dataset: %s in project: %s", stagingTableName, bq.namespace, bq.projectID) - stagingTableColMap := bq.uploader.GetTableSchemaInWarehouse(tableName) - sampleSchema := getTableSchema(stagingTableColMap) - metaData := &bigquery.TableMetadata{ - Schema: sampleSchema, - TimePartitioning: &bigquery.TimePartitioning{}, - } - tableRef := bq.db.Dataset(bq.namespace).Table(stagingTableName) - err = tableRef.Create(ctx, metaData) - if err != nil { - bq.logger.Infof("BQ: Error creating temporary staging table %s", stagingTableName) - return - } + log.Infow("loading data into main table") + job, err := bq.db.Dataset(bq.namespace).Table(outputTable).LoaderFrom(gcsRef).Run(ctx) + if err != nil { + return nil, nil, fmt.Errorf("moving data into main table: %w", err) + } - loader := bq.db.Dataset(bq.namespace).Table(stagingTableName).LoaderFrom(gcsRef) - job, err := loader.Run(ctx) - if err != nil { - bq.logger.Errorf("BQ: Error initiating staging table load job: %v\n", err) - return - } - status, err := job.Wait(ctx) - if err != nil { - bq.logger.Errorf("BQ: Error running staging table load job: %v\n", err) - return - } + log.Debugw("waiting for append job to complete", "jobID", job.ID()) + status, err := job.Wait(ctx) + if err != nil { + return nil, nil, fmt.Errorf("waiting for append job: %w", err) + } + if err := status.Err(); err != nil { + return nil, nil, fmt.Errorf("status for append job: %w", err) + } - if status.Err() != nil { - return status.Err() - } + log.Debugw("job statistics") + statistics, err := bq.jobStatistics(ctx, job) + if err != nil { + return nil, nil, fmt.Errorf("append job statistics: %w", err) + } - if !skipTempTableDelete { - defer bq.dropStagingTable(ctx, stagingTableName) - } + log.Infow("completed loading") - primaryKey := "id" - if column, ok := primaryKeyMap[tableName]; ok { - primaryKey = column - } + tableStats := &types.LoadTableStats{ + RowsInserted: statistics.Load.OutputRows, + } + response := &loadTableResponse{ + partitionDate: partitionDate, + } + return tableStats, response, nil +} - partitionKey := "id" - if column, ok := partitionKeyMap[tableName]; ok { - partitionKey = column - } +func (bq *BigQuery) jobStatistics( + ctx context.Context, + job *bigquery.Job, +) (*bqService.JobStatistics, error) { + serv, err := bqService.NewService( + ctx, + option.WithCredentialsJSON([]byte(warehouseutils.GetConfigValue(credentials, bq.warehouse))), + ) + if err != nil { + return nil, fmt.Errorf("creating service: %w", err) + } - tableColMap := bq.uploader.GetTableSchemaInWarehouse(tableName) - var tableColNames 
[]string - for colName := range tableColMap { - tableColNames = append(tableColNames, fmt.Sprintf("`%s`", colName)) - } + bqJobGetCall := bqService.NewJobsService(serv).Get( + job.ProjectID(), + job.ID(), + ) + bqJob, err := bqJobGetCall.Context(ctx).Location(job.Location()).Fields("statistics").Do() + if err != nil { + return nil, fmt.Errorf("getting job: %w", err) + } + return bqJob.Statistics, nil +} - var stagingColumnNamesList, columnsWithValuesList []string - for _, str := range tableColNames { - stagingColumnNamesList = append(stagingColumnNamesList, fmt.Sprintf(`staging.%s`, str)) - columnsWithValuesList = append(columnsWithValuesList, fmt.Sprintf(`original.%[1]s = staging.%[1]s`, str)) - } - columnNames := strings.Join(tableColNames, ",") - stagingColumnNames := strings.Join(stagingColumnNamesList, ",") - columnsWithValues := strings.Join(columnsWithValuesList, ",") +func (bq *BigQuery) loadTableByMerge( + ctx context.Context, + tableName string, + gcsRef *bigquery.GCSReference, + log logger.Logger, + skipTempTableDelete bool, +) (*types.LoadTableStats, *loadTableResponse, error) { + stagingTableName := warehouseutils.StagingTableName( + provider, + tableName, + tableNameLimit, + ) - var primaryKeyList []string - for _, str := range strings.Split(primaryKey, ",") { - primaryKeyList = append(primaryKeyList, fmt.Sprintf(`original.%[1]s = staging.%[1]s`, strings.Trim(str, " "))) - } - primaryJoinClause := strings.Join(primaryKeyList, " AND ") - bqTable := func(name string) string { return fmt.Sprintf("`%s`.`%s`", bq.namespace, name) } + sampleSchema := getTableSchema(bq.uploader.GetTableSchemaInWarehouse( + tableName, + )) - var orderByClause string - if _, ok := tableColMap["received_at"]; ok { - orderByClause = "ORDER BY received_at DESC" - } + log.Debugw("creating staging table") + err := bq.db.Dataset(bq.namespace).Table(stagingTableName).Create(ctx, &bigquery.TableMetadata{ + Schema: sampleSchema, + TimePartitioning: &bigquery.TimePartitioning{}, + }) + if err != nil { + return nil, nil, fmt.Errorf("creating staging table: %w", err) + } - sqlStatement := fmt.Sprintf(`MERGE INTO %[1]s AS original - USING ( - SELECT * FROM ( - SELECT *, row_number() OVER (PARTITION BY %[7]s %[8]s) AS _rudder_staging_row_number FROM %[2]s - ) AS q WHERE _rudder_staging_row_number = 1 - ) AS staging - ON (%[3]s) - WHEN MATCHED THEN - UPDATE SET %[6]s - WHEN NOT MATCHED THEN - INSERT (%[4]s) VALUES (%[5]s)`, - bqTable(tableName), - bqTable(stagingTableName), - primaryJoinClause, - columnNames, - stagingColumnNames, - columnsWithValues, - partitionKey, - orderByClause, - ) - bq.logger.Infof("BQ: Dedup records for table:%s using staging table: %s\n", tableName, sqlStatement) + log.Infow("loading data into staging table") + job, err := bq.db.Dataset(bq.namespace).Table(stagingTableName).LoaderFrom(gcsRef).Run(ctx) + if err != nil { + return nil, nil, fmt.Errorf("loading into staging table: %w", err) + } - q := bq.db.Query(sqlStatement) - job, err = bq.getMiddleware().Run(ctx, q) - if err != nil { - bq.logger.Errorf("BQ: Error initiating merge load job: %v\n", err) - return - } - status, err = job.Wait(ctx) - if err != nil { - bq.logger.Errorf("BQ: Error running merge load job: %v\n", err) - return - } + log.Debugw("waiting for load job to complete", "jobID", job.ID()) + status, err := job.Wait(ctx) + if err != nil { + return nil, nil, fmt.Errorf("waiting for job: %w", err) + } + if err := status.Err(); err != nil { + return nil, nil, fmt.Errorf("status for job: %w", err) + } - if status.Err() != nil { - 
return status.Err() - } - return + if !skipTempTableDelete { + defer bq.dropStagingTable(ctx, stagingTableName) } - if !bq.dedupEnabled() { - err = loadTableByAppend() - return + tableColMap := bq.uploader.GetTableSchemaInWarehouse(tableName) + tableColNames := lo.MapToSlice(tableColMap, func(colName, _ string) string { + return fmt.Sprintf("`%s`", colName) + }) + + columnNames := strings.Join(tableColNames, ",") + stagingColumnNames := strings.Join(lo.Map(tableColNames, func(colName string, _ int) string { + return fmt.Sprintf(`staging.%s`, colName) + }), ",") + columnsWithValues := strings.Join(lo.Map(tableColNames, func(colName string, _ int) string { + return fmt.Sprintf(`original.%[1]s = staging.%[1]s`, colName) + }), ",") + + primaryKey := "id" + if column, ok := primaryKeyMap[tableName]; ok { + primaryKey = column + } + partitionKey := "id" + if column, ok := partitionKeyMap[tableName]; ok { + partitionKey = column } - err = loadTableByMerge() - return + primaryJoinClause := strings.Join(lo.Map(strings.Split(primaryKey, ","), func(str string, _ int) string { + return fmt.Sprintf(`original.%[1]s = staging.%[1]s`, strings.Trim(str, " ")) + }), " AND ") + + bqTable := func(name string) string { + return fmt.Sprintf("`%s`.`%s`", bq.namespace, name) + } + + var orderByClause string + if _, ok := tableColMap["received_at"]; ok { + orderByClause = "ORDER BY received_at DESC" + } + + mergeIntoStmt := fmt.Sprintf(` + MERGE INTO %[1]s AS original USING ( + SELECT + * + FROM + ( + SELECT + *, + row_number() OVER (PARTITION BY %[7]s %[8]s) AS _rudder_staging_row_number + FROM + %[2]s + ) AS q + WHERE + _rudder_staging_row_number = 1 + ) AS staging ON (%[3]s) WHEN MATCHED THEN + UPDATE + SET + %[6]s WHEN NOT MATCHED THEN INSERT (%[4]s) + VALUES + (%[5]s); +`, + bqTable(tableName), + bqTable(stagingTableName), + primaryJoinClause, + columnNames, + stagingColumnNames, + columnsWithValues, + partitionKey, + orderByClause, + ) + + log.Infow("merging data from staging table into main table") + job, err = bq.getMiddleware().Run(ctx, bq.db.Query(mergeIntoStmt)) + if err != nil { + return nil, nil, fmt.Errorf("moving data to main table: %w", err) + } + + log.Debugw("waiting for merge job to complete", "jobID", job.ID()) + status, err = job.Wait(ctx) + if err != nil { + return nil, nil, fmt.Errorf("waiting for merge job: %w", err) + } + if err := status.Err(); err != nil { + return nil, nil, fmt.Errorf("status for merge job: %w", err) + } + + log.Debugw("job statistics") + statistics, err := bq.jobStatistics(ctx, job) + if err != nil { + return nil, nil, fmt.Errorf("merge job statistics: %w", err) + } + + log.Infow("completed loading") + + tableStats := &types.LoadTableStats{ + RowsInserted: statistics.Query.DmlStats.InsertedRowCount, + RowsUpdated: statistics.Query.DmlStats.UpdatedRowCount, + } + response := &loadTableResponse{ + stagingTableName: stagingTableName, + } + return tableStats, response, nil } func (bq *BigQuery) LoadUserTables(ctx context.Context) (errorMap map[string]error) { errorMap = map[string]error{warehouseutils.IdentifiesTable: nil} bq.logger.Infof("BQ: Starting load for identifies and users tables\n") - identifyLoadTable, err := bq.loadTable(ctx, warehouseutils.IdentifiesTable, true, false, true) + _, identifyLoadTable, err := bq.loadTable(ctx, warehouseutils.IdentifiesTable, true) if err != nil { errorMap[warehouseutils.IdentifiesTable] = err return @@ -770,10 +893,10 @@ func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { for { var values []bigquery.Value err 
:= it.Next(&values) - if err == iterator.Done { - break - } if err != nil { + if errors.Is(err, iterator.Done) { + break + } bq.logger.Errorf("BQ: Error in processing fetched staging tables from information schema in dataset %v : %v", bq.namespace, err) return false } @@ -793,41 +916,41 @@ func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { return delSuccess } -func (bq *BigQuery) IsEmpty(ctx context.Context, warehouse model.Warehouse) (empty bool, err error) { - empty = true +func (bq *BigQuery) IsEmpty( + ctx context.Context, + warehouse model.Warehouse, +) (bool, error) { bq.warehouse = warehouse bq.namespace = warehouse.Namespace bq.projectID = strings.TrimSpace(warehouseutils.GetConfigValue(project, bq.warehouse)) - bq.logger.Infof("BQ: Connecting to BigQuery in project: %s", bq.projectID) + + var err error bq.db, err = bq.connect(ctx, BQCredentials{ ProjectID: bq.projectID, Credentials: warehouseutils.GetConfigValue(credentials, bq.warehouse), }) if err != nil { - return + return false, fmt.Errorf("connecting to bigquery: %v", err) } defer func() { _ = bq.db.Close() }() tables := []string{"tracks", "pages", "screens", "identifies", "aliases"} for _, tableName := range tables { - var exists bool - exists, err = bq.tableExists(ctx, tableName) + exists, err := bq.tableExists(ctx, tableName) if err != nil { - return + return false, fmt.Errorf("checking if table %s exists: %v", tableName, err) } if !exists { continue } - count, err := bq.GetTotalCountInTable(ctx, tableName) + + metadata, err := bq.db.Dataset(bq.namespace).Table(tableName).Metadata(ctx) if err != nil { - return empty, err - } - if count > 0 { - empty = false - return empty, nil + return false, fmt.Errorf("getting metadata for table %s: %v", tableName, err) } + return metadata.NumRows == 0, nil } - return + return true, nil } func (bq *BigQuery) Setup(ctx context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { @@ -847,16 +970,13 @@ func (*BigQuery) TestConnection(context.Context, model.Warehouse) (err error) { return nil } -func (bq *BigQuery) LoadTable(ctx context.Context, tableName string) error { - var getLoadFileLocFromTableUploads bool - switch tableName { - case warehouseutils.IdentityMappingsTable, warehouseutils.IdentityMergeRulesTable: - getLoadFileLocFromTableUploads = true - default: - getLoadFileLocFromTableUploads = false - } - _, err := bq.loadTable(ctx, tableName, false, getLoadFileLocFromTableUploads, false) - return err +func (bq *BigQuery) LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) { + loadTableStat, _, err := bq.loadTable( + ctx, + tableName, + false, + ) + return loadTableStat, err } func (bq *BigQuery) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { @@ -932,10 +1052,10 @@ func (bq *BigQuery) FetchSchema(ctx context.Context) (model.Schema, model.Schema var values []bigquery.Value err := it.Next(&values) - if err == iterator.Done { - break - } if err != nil { + if errors.Is(err, iterator.Done) { + break + } return nil, nil, fmt.Errorf("iterating schema: %w", err) } @@ -975,12 +1095,14 @@ func (bq *BigQuery) Cleanup(context.Context) { func (bq *BigQuery) LoadIdentityMergeRulesTable(ctx context.Context) (err error) { identityMergeRulesTable := warehouseutils.IdentityMergeRulesWarehouseTableName(warehouseutils.BQ) - return bq.LoadTable(ctx, identityMergeRulesTable) + _, err = bq.LoadTable(ctx, identityMergeRulesTable) + return err } func (bq *BigQuery) 
LoadIdentityMappingsTable(ctx context.Context) (err error) { identityMappingsTable := warehouseutils.IdentityMappingsWarehouseTableName(warehouseutils.BQ) - return bq.LoadTable(ctx, identityMappingsTable) + _, err = bq.LoadTable(ctx, identityMappingsTable) + return err } func (bq *BigQuery) tableExists(ctx context.Context, tableName string) (exists bool, err error) { @@ -1079,10 +1201,10 @@ func (bq *BigQuery) DownloadIdentityRules(ctx context.Context, gzWriter *misc.GZ var values []bigquery.Value err := it.Next(&values) - if err == iterator.Done { - break - } if err != nil { + if errors.Is(err, iterator.Done) { + break + } return err } var anonId, userId string @@ -1126,43 +1248,6 @@ func (bq *BigQuery) DownloadIdentityRules(ctx context.Context, gzWriter *misc.GZ return } -func (bq *BigQuery) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) { - var ( - total int64 - err error - sqlStatement string - ok bool - - it *bigquery.RowIterator - values []bigquery.Value - ) - sqlStatement = fmt.Sprintf(` - SELECT count(*) FROM %[1]s.%[2]s; - `, - bq.namespace, - tableName, - ) - - query := bq.db.Query(sqlStatement) - if it, err = bq.getMiddleware().Read(ctx, query); err != nil { - return 0, fmt.Errorf("creating row iterator: %w", err) - } - - err = it.Next(&values) - if err == iterator.Done { - return 0, nil - } - if err != nil { - return 0, fmt.Errorf("iterating through rows: %w", err) - } - - if total, ok = values[0].(int64); !ok { - return 0, fmt.Errorf("converting value to int64: %w", err) - } - - return total, nil -} - func (bq *BigQuery) Connect(ctx context.Context, warehouse model.Warehouse) (client.Client, error) { bq.warehouse = warehouse bq.namespace = warehouse.Namespace diff --git a/warehouse/integrations/bigquery/bigquery_test.go b/warehouse/integrations/bigquery/bigquery_test.go index f2150935ff..0af0d31253 100644 --- a/warehouse/integrations/bigquery/bigquery_test.go +++ b/warehouse/integrations/bigquery/bigquery_test.go @@ -10,6 +10,15 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" + "golang.org/x/exp/slices" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/logger" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "cloud.google.com/go/bigquery" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -388,6 +397,575 @@ func TestIntegration(t *testing.T) { } testhelper.VerifyConfigurationTest(t, dest) }) + + t.Run("Load Table", func(t *testing.T) { + const ( + sourceID = "test_source_id" + destinationID = "test_destination_id" + workspaceID = "test_workspace_id" + ) + + namespace := testhelper.RandSchema(destType) + + t.Cleanup(func() { + require.Eventually(t, func() bool { + if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { + t.Logf("error deleting dataset: %v", err) + return false + } + return true + }, + time.Minute, + time.Second, + ) + }) + + schemaInUpload := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + } + schemaInWarehouse := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + "extra_test_bool": "boolean", + 
"extra_test_datetime": "datetime", + "extra_test_float": "float", + "extra_test_int": "int", + "extra_test_string": "string", + } + + credentials, err := bqHelper.GetBQTestCredentials() + require.NoError(t, err) + + warehouse := model.Warehouse{ + Source: backendconfig.SourceT{ + ID: sourceID, + }, + Destination: backendconfig.DestinationT{ + ID: destinationID, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: destType, + }, + Config: map[string]any{ + "project": credentials.ProjectID, + "location": credentials.Location, + "bucketName": credentials.BucketName, + "credentials": credentials.Credentials, + "namespace": namespace, + }, + }, + WorkspaceID: workspaceID, + Namespace: namespace, + } + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: warehouseutils.GCS, + Config: map[string]any{ + "project": credentials.ProjectID, + "location": credentials.Location, + "bucketName": credentials.BucketName, + "credentials": credentials.Credentials, + }, + }) + require.NoError(t, err) + + t.Run("schema does not exists", func(t *testing.T) { + tableName := "schema_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("table does not exists", func(t *testing.T) { + tableName := "table_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("merge", func(t *testing.T) { + tableName := "merge_test_table" + + c := config.New() + c.Set("Warehouse.bigquery.isDedupEnabled", true) + + t.Run("without dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + bq := whbigquery.New(c, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := bqHelper.RetrieveRecordsFromWarehouse(t, db, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + 
test_int, + test_string + FROM + %s + ORDER BY + id; + `, + fmt.Sprintf("`%s`.`%s`", namespace, tableName), + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("with dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.json.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + bq := whbigquery.New(c, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := bqHelper.RetrieveRecordsFromWarehouse(t, db, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s + ORDER BY + id; + `, + fmt.Sprintf("`%s`.`%s`", namespace, tableName), + ), + ) + require.Equal(t, records, testhelper.DedupTestRecords()) + }) + }) + t.Run("append", func(t *testing.T) { + tableName := "append_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := bqHelper.RetrieveRecordsFromWarehouse(t, db, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + WHERE + _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') + ORDER BY + id; + `, + namespace, + tableName, + time.Now().Add(-24*time.Hour).Format("2006-01-02"), + time.Now().Add(+24*time.Hour).Format("2006-01-02"), + ), + ) + require.Equal(t, records, testhelper.AppendTestRecords()) + }) + t.Run("load file does not exists", func(t *testing.T) { + tableName := "load_file_not_exists_test_table" + + loadFiles := []warehouseutils.LoadFile{{ + Location: "https://storage.googleapis.com/project/rudder-warehouse-load-objects/load_file_not_exists_test_table/test_source_id/2e04b6bd-8007-461e-a338-91224a8b7d3d-load_file_not_exists_test_table/load.json.gz", + }} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) 
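+			// the load file location above points at a GCS object that was never uploaded,
+			// so the LoadTable call above is expected to fail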
+ require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in number of columns", func(t *testing.T) { + tableName := "mismatch_columns_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.json.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in schema", func(t *testing.T) { + tableName := "mismatch_schema_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.json.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("discards", func(t *testing.T) { + tableName := warehouseutils.DiscardsTable + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.json.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) + + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + err = bq.CreateTable(ctx, tableName, warehouseutils.DiscardsSchema) + require.NoError(t, err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(6)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := bqHelper.RetrieveRecordsFromWarehouse(t, db, + fmt.Sprintf(` + SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM + %s + ORDER BY row_id ASC; + `, + fmt.Sprintf("`%s`.`%s`", namespace, tableName), + ), + ) + require.Equal(t, records, testhelper.DiscardTestRecords()) + }) + t.Run("custom partition", func(t *testing.T) { + tableName := "partition_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.json.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader( + t, loadFiles, tableName, schemaInUpload, + schemaInWarehouse, + ) + + c := config.New() + c.Set("Warehouse.bigquery.customPartitionsEnabled", true) + c.Set("Warehouse.bigquery.customPartitionsEnabledWorkspaceIDs", []string{workspaceID}) + + bq := whbigquery.New(c, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + err = bq.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, 
err) + + loadTableStat, err := bq.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := bqHelper.RetrieveRecordsFromWarehouse(t, db, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + WHERE + _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s') + ORDER BY + id; + `, + namespace, + tableName, + time.Now().Add(-24*time.Hour).Format("2006-01-02"), + time.Now().Add(+24*time.Hour).Format("2006-01-02"), + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + }) + + t.Run("IsEmpty", func(t *testing.T) { + namespace := testhelper.RandSchema(warehouseutils.BQ) + + t.Cleanup(func() { + require.Eventually(t, func() bool { + if err := db.Dataset(namespace).DeleteWithContents(ctx); err != nil { + t.Logf("error deleting dataset: %v", err) + return false + } + return true + }, + time.Minute, + time.Second, + ) + }) + + ctx := context.Background() + + credentials, err := bqHelper.GetBQTestCredentials() + require.NoError(t, err) + + warehouse := model.Warehouse{ + Source: backendconfig.SourceT{ + ID: sourceID, + }, + Destination: backendconfig.DestinationT{ + ID: destinationID, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: warehouseutils.BQ, + }, + Config: map[string]any{ + "project": credentials.ProjectID, + "location": credentials.Location, + "bucketName": credentials.BucketName, + "credentials": credentials.Credentials, + "namespace": namespace, + }, + }, + WorkspaceID: workspaceID, + Namespace: namespace, + } + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockUploader := mockuploader.NewMockUploader(ctrl) + mockUploader.EXPECT().UseRudderStorage().Return(false).AnyTimes() + + insertRecords := func(t testing.TB, tableName string) { + t.Helper() + + query := db.Query(` + INSERT INTO ` + tableName + ` ( + id, received_at, test_bool, test_datetime, + test_float, test_int, test_string + ) + VALUES + ( + '1', '2020-01-01 00:00:00', true, + '2020-01-01 00:00:00', 1.1, 1, 'test' + );`, + ) + job, err := query.Run(ctx) + require.NoError(t, err) + + status, err := job.Wait(ctx) + require.NoError(t, err) + require.Nil(t, status.Err()) + } + + t.Run("tables doesn't exists", func(t *testing.T) { + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + isEmpty, err := bq.IsEmpty(ctx, warehouse) + require.NoError(t, err) + require.True(t, isEmpty) + }) + t.Run("tables empty", func(t *testing.T) { + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = bq.CreateSchema(ctx) + require.NoError(t, err) + + tables := []string{"pages", "screens"} + for _, table := range tables { + err = bq.CreateTable(ctx, table, model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + }) + require.NoError(t, err) + } + + isEmpty, err := bq.IsEmpty(ctx, warehouse) + require.NoError(t, err) + require.True(t, isEmpty) + }) + t.Run("tables not empty", func(t *testing.T) { + bq := whbigquery.New(config.Default, logger.NOP) + err := bq.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + insertRecords(t, fmt.Sprintf("`%s`.`%s`", namespace, "pages")) + insertRecords(t, 
fmt.Sprintf("`%s`.`%s`", namespace, "screens")) + + isEmpty, err := bq.IsEmpty(ctx, warehouse) + require.NoError(t, err) + require.False(t, isEmpty) + }) + }) +} + +func newMockUploader( + t testing.TB, + loadFiles []warehouseutils.LoadFile, + tableName string, + schemaInUpload model.TableSchema, + schemaInWarehouse model.TableSchema, +) warehouseutils.Uploader { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockUploader := mockuploader.NewMockUploader(ctrl) + mockUploader.EXPECT().UseRudderStorage().Return(false).AnyTimes() + mockUploader.EXPECT().GetLoadFilesMetadata(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, options warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { + return slices.Clone(loadFiles) + }, + ).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() + + return mockUploader } func loadFilesEventsMap() testhelper.EventsCountMap { diff --git a/warehouse/integrations/bigquery/testhelper/setup.go b/warehouse/integrations/bigquery/testhelper/setup.go index eb44eb8e17..d08a0dd29b 100644 --- a/warehouse/integrations/bigquery/testhelper/setup.go +++ b/warehouse/integrations/bigquery/testhelper/setup.go @@ -1,9 +1,20 @@ package testhelper import ( + "context" "encoding/json" + "errors" "fmt" "os" + "testing" + "time" + + "cloud.google.com/go/bigquery" + + "github.com/samber/lo" + "github.com/spf13/cast" + "github.com/stretchr/testify/require" + "google.golang.org/api/iterator" ) type TestCredentials struct { @@ -34,3 +45,35 @@ func IsBQTestCredentialsAvailable() bool { _, err := GetBQTestCredentials() return err == nil } + +// RetrieveRecordsFromWarehouse retrieves records from the warehouse based on the given query. +// It returns a slice of slices, where each inner slice represents a record's values. 
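+//
+// A minimal usage sketch (illustrative, not part of the change; assumes a *bigquery.Client
+// named db and a query against an existing dataset/table):
+//
+//	records := RetrieveRecordsFromWarehouse(t, db,
+//		fmt.Sprintf("SELECT id, received_at FROM `%s`.`%s` ORDER BY id;", namespace, tableName),
+//	)
+//	require.Equal(t, expectedRecords, records)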
+func RetrieveRecordsFromWarehouse(
+	t testing.TB,
+	db *bigquery.Client,
+	query string,
+) [][]string {
+	t.Helper()
+
+	it, err := db.Query(query).Read(context.Background())
+	require.NoError(t, err)
+
+	var records [][]string
+	for {
+		var row []bigquery.Value
+		err = it.Next(&row)
+		if errors.Is(err, iterator.Done) {
+			break
+		}
+		require.NoError(t, err)
+
+		records = append(records, lo.Map(row, func(item bigquery.Value, index int) string {
+			switch item := item.(type) {
+			case time.Time:
+				return item.Format(time.RFC3339)
+			default:
+				return cast.ToString(item)
+			}
+		}))
+	}
+	return records
+}
diff --git a/warehouse/integrations/clickhouse/clickhouse.go b/warehouse/integrations/clickhouse/clickhouse.go
index 2d0182bd2a..0b57498a99 100644
--- a/warehouse/integrations/clickhouse/clickhouse.go
+++ b/warehouse/integrations/clickhouse/clickhouse.go
@@ -20,6 +20,8 @@ import (
 	"strings"
 	"time"
 
+	"github.com/rudderlabs/rudder-server/warehouse/integrations/types"
+
 	sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
 
 	"github.com/rudderlabs/rudder-server/warehouse/logfield"
@@ -688,7 +690,7 @@ func (ch *Clickhouse) loadTablesFromFilesNamesWithRetry(ctx context.Context, tab
 		var record []string
 		record, err = csvReader.Read()
 		if err != nil {
-			if err == io.EOF {
+			if errors.Is(err, io.EOF) {
 				ch.logger.Debugf("%s File reading completed while reading csv file for loading in table for objectFileName:%s", ch.GetLogIdentifier(tableName), objectFileName)
 				break
 			}
@@ -1033,9 +1035,25 @@ func (ch *Clickhouse) LoadUserTables(ctx context.Context) (errorMap map[string]e
 	return
 }
 
-func (ch *Clickhouse) LoadTable(ctx context.Context, tableName string) error {
-	err := ch.loadTable(ctx, tableName, ch.Uploader.GetTableSchemaInUpload(tableName))
-	return err
+func (ch *Clickhouse) LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) {
+	preLoadTableCount, err := ch.totalCountInTable(ctx, tableName)
+	if err != nil {
+		return nil, fmt.Errorf("pre load table count: %w", err)
+	}
+
+	err = ch.loadTable(ctx, tableName, ch.Uploader.GetTableSchemaInUpload(tableName))
+	if err != nil {
+		return nil, fmt.Errorf("loading table: %w", err)
+	}
+
+	postLoadTableCount, err := ch.totalCountInTable(ctx, tableName)
+	if err != nil {
+		return nil, fmt.Errorf("post load table count: %w", err)
+	}
+
+	return &types.LoadTableStats{
+		RowsInserted: postLoadTableCount - preLoadTableCount,
+	}, nil
 }
 
 func (ch *Clickhouse) Cleanup(context.Context) {
@@ -1060,7 +1078,7 @@ func (*Clickhouse) IsEmpty(context.Context, model.Warehouse) (empty bool, err er
 	return
 }
 
-func (ch *Clickhouse) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) {
+func (ch *Clickhouse) totalCountInTable(ctx context.Context, tableName string) (int64, error) {
 	var (
 		total int64
 		err   error
diff --git a/warehouse/integrations/clickhouse/clickhouse_test.go b/warehouse/integrations/clickhouse/clickhouse_test.go
index 4f230511cf..ea386d1f14 100644
--- a/warehouse/integrations/clickhouse/clickhouse_test.go
+++ b/warehouse/integrations/clickhouse/clickhouse_test.go
@@ -573,14 +573,9 @@ func TestClickhouse_LoadTableRoundTrip(t *testing.T) {
 	}
 
 	t.Log("Loading data into table")
-	err = ch.LoadTable(ctx, table)
+	_, err = ch.LoadTable(ctx, table)
 	require.NoError(t, err)
 
-	t.Log("Checking table count")
-	count, err := ch.GetTotalCountInTable(ctx, table)
-	require.NoError(t, err)
-	require.EqualValues(t, 2, count)
-
 	t.Log("Drop table")
 	err = ch.DropTable(ctx, table)
 	require.NoError(t, err)
diff --git
a/warehouse/integrations/datalake/datalake.go b/warehouse/integrations/datalake/datalake.go index e33d87e3bd..4251009552 100644 --- a/warehouse/integrations/datalake/datalake.go +++ b/warehouse/integrations/datalake/datalake.go @@ -6,6 +6,8 @@ import ( "regexp" "time" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" schemarepository "github.com/rudderlabs/rudder-server/warehouse/integrations/datalake/schema-repository" @@ -77,9 +79,9 @@ func (d *Datalake) AlterColumn(ctx context.Context, tableName, columnName, colum return d.SchemaRepository.AlterColumn(ctx, tableName, columnName, columnType) } -func (d *Datalake) LoadTable(_ context.Context, tableName string) error { +func (d *Datalake) LoadTable(_ context.Context, tableName string) (*types.LoadTableStats, error) { d.logger.Infof("Skipping load for table %s : %s is a datalake destination", tableName, d.Warehouse.Destination.ID) - return nil + return &types.LoadTableStats{}, nil } func (*Datalake) DeleteBy(context.Context, []string, warehouseutils.DeleteByParams) (err error) { @@ -122,10 +124,6 @@ func (*Datalake) DownloadIdentityRules(context.Context, *misc.GZipWriter) error return fmt.Errorf("datalake err :not implemented") } -func (*Datalake) GetTotalCountInTable(context.Context, string) (int64, error) { - return 0, nil -} - func (*Datalake) Connect(context.Context, model.Warehouse) (client.Client, error) { return client.Client{}, fmt.Errorf("datalake err :not implemented") } diff --git a/warehouse/integrations/deltalake/deltalake.go b/warehouse/integrations/deltalake/deltalake.go index 10299da64a..eec17336cd 100644 --- a/warehouse/integrations/deltalake/deltalake.go +++ b/warehouse/integrations/deltalake/deltalake.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" + dbsql "github.com/databricks/databricks-sql-go" dbsqllog "github.com/databricks/databricks-sql-go/logger" "golang.org/x/exp/slices" @@ -546,29 +548,31 @@ func (*Deltalake) AlterColumn(context.Context, string, string, string) (model.Al } // LoadTable loads table for table name -func (d *Deltalake) LoadTable(ctx context.Context, tableName string) error { +func (d *Deltalake) LoadTable( + ctx context.Context, + tableName string, +) (*types.LoadTableStats, error) { uploadTableSchema := d.Uploader.GetTableSchemaInUpload(tableName) warehouseTableSchema := d.Uploader.GetTableSchemaInWarehouse(tableName) - _, err := d.loadTable(ctx, tableName, uploadTableSchema, warehouseTableSchema, false) - if err != nil { - return fmt.Errorf("loading table: %w", err) - } - - return nil -} - -func (d *Deltalake) loadTable(ctx context.Context, tableName string, tableSchemaInUpload, tableSchemaAfterUpload model.TableSchema, skipTempTableDelete bool) (string, error) { - var ( - sortedColumnKeys = warehouseutils.SortColumnKeysFromColumnMap(tableSchemaInUpload) - stagingTableName = warehouseutils.StagingTableName(provider, tableName, tableNameLimit) - - err error - auth string - row *sqlmiddleware.Row + loadTableStat, _, err := d.loadTable( + ctx, + tableName, + uploadTableSchema, + warehouseTableSchema, + false, ) + return loadTableStat, err +} - d.logger.Infow("started loading", +func (d *Deltalake) loadTable( + ctx context.Context, + tableName string, + tableSchemaInUpload model.TableSchema, + tableSchemaAfterUpload model.TableSchema, + skipTempTableDelete bool, +) (*types.LoadTableStats, string, error) { + log := d.logger.With( logfield.SourceID, 
d.Warehouse.Source.ID, logfield.SourceType, d.Warehouse.Source.SourceDefinition.Name, logfield.DestinationID, d.Warehouse.Destination.ID, @@ -576,36 +580,84 @@ func (d *Deltalake) loadTable(ctx context.Context, tableName string, tableSchema logfield.WorkspaceID, d.Warehouse.WorkspaceID, logfield.Namespace, d.Namespace, logfield.TableName, tableName, + logfield.LoadTableStrategy, d.config.loadTableStrategy, + ) + log.Infow("started loading") + + stagingTableName := warehouseutils.StagingTableName( + provider, + tableName, + tableNameLimit, ) - if err = d.CreateTable(ctx, stagingTableName, tableSchemaAfterUpload); err != nil { - return "", fmt.Errorf("creating staging table: %w", err) + log.Debugw("creating staging table") + if err := d.CreateTable(ctx, stagingTableName, tableSchemaAfterUpload); err != nil { + return nil, "", fmt.Errorf("creating staging table: %w", err) } if !skipTempTableDelete { - defer d.dropStagingTables(ctx, []string{stagingTableName}) + defer func() { + d.dropStagingTables(ctx, []string{stagingTableName}) + }() } - if auth, err = d.authQuery(); err != nil { - return "", fmt.Errorf("getting auth query: %w", err) + log.Infow("copying data into staging table") + err := d.copyIntoLoadTable( + ctx, tableName, stagingTableName, + tableSchemaInUpload, tableSchemaAfterUpload, + ) + if err != nil { + return nil, "", fmt.Errorf("copying into staging table: %w", err) } - objectsLocation, err := d.Uploader.GetSampleLoadFileLocation(ctx, tableName) + var loadTableStat *types.LoadTableStats + if d.ShouldAppend() { + log.Infow("inserting data from staging table to main table") + loadTableStat, err = d.insertIntoLoadTable( + ctx, tableName, stagingTableName, + tableSchemaAfterUpload, + ) + } else { + log.Infow("merging data from staging table to main table") + loadTableStat, err = d.mergeIntoLoadTable( + ctx, tableName, stagingTableName, + tableSchemaInUpload, + ) + } if err != nil { - return "", fmt.Errorf("getting sample load file location: %w", err) + return nil, "", fmt.Errorf("moving data from main table to staging table: %w", err) } - var ( - loadFolder = d.getLoadFolder(objectsLocation) - tableSchemaDiff = tableSchemaDiff(tableSchemaInUpload, tableSchemaAfterUpload) - sortedColumnNames = d.sortedColumnNames(tableSchemaInUpload, sortedColumnKeys, tableSchemaDiff) + log.Infow("completed loading") - query string - partitionQuery string - ) + return loadTableStat, stagingTableName, nil +} +func (d *Deltalake) copyIntoLoadTable( + ctx context.Context, + tableName string, + stagingTableName string, + tableSchemaInUpload model.TableSchema, + tableSchemaAfterUpload model.TableSchema, +) error { + auth, err := d.authQuery() + if err != nil { + return fmt.Errorf("getting auth query: %w", err) + } + + objectsLocation, err := d.Uploader.GetSampleLoadFileLocation(ctx, tableName) + if err != nil { + return fmt.Errorf("getting sample load file location: %w", err) + } + + loadFolder := d.getLoadFolder(objectsLocation) + tableSchemaDiff := tableSchemaDiff(tableSchemaInUpload, tableSchemaAfterUpload) + sortedColumnKeys := warehouseutils.SortColumnKeysFromColumnMap(tableSchemaInUpload) + sortedColumnNames := d.sortedColumnNames(tableSchemaInUpload, sortedColumnKeys, tableSchemaDiff) + + var copyStmt string if d.Uploader.GetLoadFileType() == warehouseutils.LoadFileTypeParquet { - query = fmt.Sprintf(` + copyStmt = fmt.Sprintf(` COPY INTO %s FROM ( @@ -620,10 +672,11 @@ func (d *Deltalake) loadTable(ctx context.Context, tableName string, tableSchema %s;`, fmt.Sprintf(`%s.%s`, d.Namespace, 
stagingTableName), sortedColumnNames, - loadFolder, auth, + loadFolder, + auth, ) } else { - query = fmt.Sprintf(` + copyStmt = fmt.Sprintf(` COPY INTO %s FROM ( @@ -650,12 +703,19 @@ func (d *Deltalake) loadTable(ctx context.Context, tableName string, tableSchema ) } - if _, err = d.DB.ExecContext(ctx, query); err != nil { - return "", fmt.Errorf("running COPY command: %w", err) + if _, err := d.DB.ExecContext(ctx, copyStmt); err != nil { + return fmt.Errorf("executing copy query: %w", err) } + return nil +} - if d.ShouldAppend() { - query = fmt.Sprintf(` +func (d *Deltalake) insertIntoLoadTable( + ctx context.Context, + tableName string, + stagingTableName string, + tableSchemaAfterUpload model.TableSchema, +) (*types.LoadTableStats, error) { + insertStmt := fmt.Sprintf(` INSERT INTO %[1]s.%[2]s (%[4]s) SELECT %[4]s @@ -679,20 +739,45 @@ func (d *Deltalake) loadTable(ctx context.Context, tableName string, tableSchema _rudder_staging_row_number = 1 ); `, - d.Namespace, - tableName, - stagingTableName, - columnNames(warehouseutils.SortColumnKeysFromColumnMap(tableSchemaAfterUpload)), - primaryKey(tableName), - ) - } else { - if partitionQuery, err = d.partitionQuery(ctx, tableName); err != nil { - return "", fmt.Errorf("getting partition query: %w", err) - } + d.Namespace, + tableName, + stagingTableName, + columnNames(warehouseutils.SortColumnKeysFromColumnMap(tableSchemaAfterUpload)), + primaryKey(tableName), + ) - pk := primaryKey(tableName) + var rowsAffected, rowsInserted int64 + err := d.DB.QueryRowContext(ctx, insertStmt).Scan( + &rowsAffected, + &rowsInserted, + ) + if err != nil { + return nil, fmt.Errorf("executing insert query: %w", err) + } - query = fmt.Sprintf(` + return &types.LoadTableStats{ + RowsInserted: rowsInserted, + }, nil +} + +func (d *Deltalake) mergeIntoLoadTable( + ctx context.Context, + tableName string, + stagingTableName string, + tableSchemaInUpload model.TableSchema, +) (*types.LoadTableStats, error) { + sortedColumnKeys := warehouseutils.SortColumnKeysFromColumnMap( + tableSchemaInUpload, + ) + + partitionQuery, err := d.partitionQuery(ctx, tableName) + if err != nil { + return nil, fmt.Errorf("getting partition query: %w", err) + } + + pk := primaryKey(tableName) + + mergeStmt := fmt.Sprintf(` MERGE INTO %[1]s.%[2]s AS MAIN USING ( SELECT * @@ -721,59 +806,31 @@ func (d *Deltalake) loadTable(ctx context.Context, tableName string, tableSchema VALUES (%[7]s); `, - d.Namespace, - tableName, - stagingTableName, - pk, - columnsWithValues(sortedColumnKeys), - columnNames(sortedColumnKeys), - stagingColumnNames(sortedColumnKeys), - partitionQuery, - ) - } - - row = d.DB.QueryRowContext(ctx, query) - - var ( - affected int64 - updated int64 - deleted int64 - inserted int64 + d.Namespace, + tableName, + stagingTableName, + pk, + columnsWithValues(sortedColumnKeys), + columnNames(sortedColumnKeys), + stagingColumnNames(sortedColumnKeys), + partitionQuery, ) - if d.ShouldAppend() { - err = row.Scan(&affected, &inserted) - } else { - err = row.Scan(&affected, &updated, &deleted, &inserted) - } - + var rowsAffected, rowsUpdated, rowsDeleted, rowsInserted int64 + err = d.DB.QueryRowContext(ctx, mergeStmt).Scan( + &rowsAffected, + &rowsUpdated, + &rowsDeleted, + &rowsInserted, + ) if err != nil { - return "", fmt.Errorf("scanning deduplication: %w", err) + return nil, fmt.Errorf("executing merge command: %w", err) } - if row.Err() != nil { - return "", fmt.Errorf("running deduplication: %w", row.Err()) - } - - d.stats.NewTaggedStat("dedup_rows", stats.CountType, 
stats.Tags{ - "sourceID": d.Warehouse.Source.ID, - "sourceType": d.Warehouse.Source.SourceDefinition.Name, - "sourceCategory": d.Warehouse.Source.SourceDefinition.Category, - "destID": d.Warehouse.Destination.ID, - "destType": d.Warehouse.Destination.DestinationDefinition.Name, - "workspaceId": d.Warehouse.WorkspaceID, - "tableName": tableName, - }).Count(int(updated)) - d.logger.Infow("completed loading", - logfield.SourceID, d.Warehouse.Source.ID, - logfield.SourceType, d.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, d.Warehouse.Destination.ID, - logfield.DestinationType, d.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, d.Warehouse.WorkspaceID, - logfield.Namespace, d.Namespace, - logfield.TableName, tableName, - ) - return stagingTableName, nil + return &types.LoadTableStats{ + RowsInserted: rowsInserted, + RowsUpdated: rowsUpdated, + }, nil } func tableSchemaDiff(tableSchemaInUpload, tableSchemaAfterUpload model.TableSchema) warehouseutils.TableSchemaDiff { @@ -786,7 +843,6 @@ func tableSchemaDiff(tableSchemaInUpload, tableSchemaAfterUpload model.TableSche diff.ColumnMap[columnName] = columnType } } - return diff } @@ -896,7 +952,7 @@ func (d *Deltalake) hasAWSCredentials() bool { // partitionQuery returns a query to fetch partitions for a table func (d *Deltalake) partitionQuery(ctx context.Context, tableName string) (string, error) { - if !d.config.enablePartitionPruning { + if !d.config.enablePartitionPruning || d.Uploader.ShouldOnDedupUseNewRecord() { return "", nil } if d.Uploader.ShouldOnDedupUseNewRecord() { @@ -961,7 +1017,7 @@ func (d *Deltalake) LoadUserTables(ctx context.Context) map[string]error { logfield.Namespace, d.Namespace, ) - identifyStagingTable, err := d.loadTable(ctx, warehouseutils.IdentifiesTable, identifiesSchemaInUpload, identifiesSchemaInWarehouse, true) + _, identifyStagingTable, err := d.loadTable(ctx, warehouseutils.IdentifiesTable, identifiesSchemaInUpload, identifiesSchemaInWarehouse, true) if err != nil { return map[string]error{ warehouseutils.IdentifiesTable: fmt.Errorf("loading table %s: %w", warehouseutils.IdentifiesTable, err), @@ -1110,7 +1166,7 @@ func (d *Deltalake) LoadUserTables(ctx context.Context) map[string]error { inserted int64 ) - if d.config.loadTableStrategy == appendMode { + if d.ShouldAppend() { err = row.Scan(&affected, &inserted) } else { err = row.Scan(&affected, &updated, &deleted, &inserted) @@ -1220,27 +1276,6 @@ func (*Deltalake) DownloadIdentityRules(context.Context, *misc.GZipWriter) error return nil } -// GetTotalCountInTable returns the total count in the table -func (d *Deltalake) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) { - query := fmt.Sprintf(` - SELECT COUNT(*) FROM %[1]s.%[2]s; - `, - d.Namespace, - tableName, - ) - - var total int64 - err := d.DB.QueryRowContext(ctx, query).Scan(&total) - if err != nil { - if strings.Contains(err.Error(), schemaNotFound) { - return 0, nil - } - return 0, fmt.Errorf("total count in table: %w", err) - } - - return total, nil -} - // Connect returns Client func (d *Deltalake) Connect(_ context.Context, warehouse model.Warehouse) (warehouseclient.Client, error) { d.Warehouse = warehouse diff --git a/warehouse/integrations/deltalake/deltalake_test.go b/warehouse/integrations/deltalake/deltalake_test.go index 2eb9831f22..4302587608 100644 --- a/warehouse/integrations/deltalake/deltalake_test.go +++ b/warehouse/integrations/deltalake/deltalake_test.go @@ -12,6 +12,11 @@ import ( "testing" "time" + 
"golang.org/x/exp/slices" + + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + dbsql "github.com/databricks/databricks-sql-go" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -379,6 +384,634 @@ func TestIntegration(t *testing.T) { }) } }) + + t.Run("Load Table", func(t *testing.T) { + const ( + sourceID = "test_source_id" + destinationID = "test_destination_id" + workspaceID = "test_workspace_id" + ) + + namespace := testhelper.RandSchema(destType) + + t.Cleanup(func() { + require.Eventually(t, func() bool { + if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %[1]s CASCADE;`, namespace)); err != nil { + t.Logf("error deleting schema: %v", err) + return false + } + return true + }, + time.Minute, + time.Second, + ) + }) + + schemaInUpload := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + } + schemaInWarehouse := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + "extra_test_bool": "boolean", + "extra_test_datetime": "datetime", + "extra_test_float": "float", + "extra_test_int": "int", + "extra_test_string": "string", + } + + warehouse := model.Warehouse{ + Source: backendconfig.SourceT{ + ID: sourceID, + }, + Destination: backendconfig.DestinationT{ + ID: destinationID, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: destType, + }, + Config: map[string]any{ + "host": deltaLakeCredentials.Host, + "port": deltaLakeCredentials.Port, + "path": deltaLakeCredentials.Path, + "token": deltaLakeCredentials.Token, + "namespace": namespace, + "bucketProvider": warehouseutils.AzureBlob, + "containerName": deltaLakeCredentials.ContainerName, + "accountName": deltaLakeCredentials.AccountName, + "accountKey": deltaLakeCredentials.AccountKey, + }, + }, + WorkspaceID: workspaceID, + Namespace: namespace, + } + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: warehouseutils.AzureBlob, + Config: map[string]any{ + "containerName": deltaLakeCredentials.ContainerName, + "accountName": deltaLakeCredentials.AccountName, + "accountKey": deltaLakeCredentials.AccountKey, + "bucketProvider": warehouseutils.AzureBlob, + }, + }) + require.NoError(t, err) + + t.Run("schema does not exists", func(t *testing.T) { + tableName := "schema_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("table does not exists", func(t *testing.T) { + tableName := "table_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, 
warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("merge", func(t *testing.T) { + tableName := "merge_test_table" + + t.Run("without dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("with dedup use new record", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, true, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DedupTestRecords()) + }) + t.Run("with no overlapping partition", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-11-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + 
err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DedupTwiceTestRecords()) + }) + }) + t.Run("append", func(t *testing.T) { + tableName := "append_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, true, false, "2022-12-15T06:53:49.640Z") + + c := config.New() + c.Set("Warehouse.deltalake.loadTableStrategy", "APPEND") + + d := deltalake.New(c, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.AppendTestRecords()) + }) + t.Run("load file does not exists", func(t *testing.T) { + tableName := "load_file_not_exists_test_table" + + loadFiles := []warehouseutils.LoadFile{{ + Location: "https://account.blob.core.windows.net/container/rudder-warehouse-load-objects/load_file_not_exists_test_table/test_source_id/a01af26e-4548-49ff-a895-258829cc1a83-load_file_not_exists_test_table/load.csv.gz", + }} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in number of columns", func(t *testing.T) { + tableName := "mismatch_columns_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") + + d := 
deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("mismatch in schema", func(t *testing.T) { + tableName := "mismatch_schema_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.MismatchSchemaTestRecords()) + }) + t.Run("discards", func(t *testing.T) { + tableName := warehouseutils.DiscardsTable + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, warehouseutils.DiscardsSchema) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(6)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM + %s.%s + ORDER BY row_id ASC; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DiscardTestRecords()) + }) + t.Run("parquet", func(t *testing.T) { + tableName := "parquet_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := 
newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeParquet, false, false, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("partition pruning", func(t *testing.T) { + t.Run("not partitioned", func(t *testing.T) { + tableName := "not_partitioned_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + _, err = d.DB.QueryContext(ctx, ` + CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( + extra_test_bool BOOLEAN, + extra_test_datetime TIMESTAMP, + extra_test_float DOUBLE, + extra_test_int BIGINT, + extra_test_string STRING, + id STRING, + received_at TIMESTAMP, + event_date DATE GENERATED ALWAYS AS ( + CAST(received_at AS DATE) + ), + test_bool BOOLEAN, + test_datetime TIMESTAMP, + test_float DOUBLE, + test_int BIGINT, + test_string STRING + ) USING DELTA; + `) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("event_date is not in partition", func(t *testing.T) { + tableName := "not_event_date_partition_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv, false, false, "2022-12-15T06:53:49.640Z") + + d := deltalake.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + _, err = d.DB.QueryContext(ctx, ` + CREATE TABLE IF NOT EXISTS `+namespace+`.`+tableName+` ( + extra_test_bool BOOLEAN, + extra_test_datetime TIMESTAMP, + extra_test_float DOUBLE, + extra_test_int BIGINT, + extra_test_string STRING, + id STRING, + received_at 
TIMESTAMP, + event_date DATE GENERATED ALWAYS AS ( + CAST(received_at AS DATE) + ), + test_bool BOOLEAN, + test_datetime TIMESTAMP, + test_float DOUBLE, + test_int BIGINT, + test_string STRING + ) USING DELTA PARTITIONED BY(id); + `) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + }) + }) } func TestDeltalake_TrimErrorMessage(t *testing.T) { @@ -474,6 +1107,41 @@ func TestDeltalake_ShouldAppend(t *testing.T) { } } +func newMockUploader( + t testing.TB, + loadFiles []warehouseutils.LoadFile, + tableName string, + schemaInUpload model.TableSchema, + schemaInWarehouse model.TableSchema, + loadFileType string, + canAppend bool, + onDedupUseNewRecords bool, + eventTS string, +) warehouseutils.Uploader { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + firstLastEventTS, err := time.Parse(time.RFC3339, eventTS) + require.NoError(t, err) + + mockUploader := mockuploader.NewMockUploader(ctrl) + mockUploader.EXPECT().UseRudderStorage().Return(false).AnyTimes() + mockUploader.EXPECT().ShouldOnDedupUseNewRecord().Return(onDedupUseNewRecords).AnyTimes() + mockUploader.EXPECT().CanAppend().Return(canAppend).AnyTimes() + mockUploader.EXPECT().GetLoadFilesMetadata(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, options warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { + return slices.Clone(loadFiles) + }, + ).AnyTimes() + mockUploader.EXPECT().GetSampleLoadFileLocation(gomock.Any(), gomock.Any()).Return(loadFiles[0].Location, nil).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() + mockUploader.EXPECT().GetLoadFileType().Return(loadFileType).AnyTimes() + mockUploader.EXPECT().GetFirstLastEvent().Return(firstLastEventTS, firstLastEventTS).AnyTimes() + + return mockUploader +} + func mergeEventsMap() testhelper.EventsCountMap { return testhelper.EventsCountMap{ "identifies": 1, diff --git a/warehouse/integrations/manager/manager.go b/warehouse/integrations/manager/manager.go index 93e151e0d5..465a977724 100644 --- a/warehouse/integrations/manager/manager.go +++ b/warehouse/integrations/manager/manager.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" + "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" @@ -31,7 +33,7 @@ type Manager interface { CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) AlterColumn(ctx context.Context, tableName, columnName, columnType string) (model.AlterTableResponse, error) - LoadTable(ctx context.Context, tableName string) error + LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) LoadUserTables(ctx context.Context) map[string]error LoadIdentityMergeRulesTable(ctx context.Context) error 
LoadIdentityMappingsTable(ctx context.Context) error @@ -39,7 +41,6 @@ type Manager interface { IsEmpty(ctx context.Context, warehouse model.Warehouse) (bool, error) TestConnection(ctx context.Context, warehouse model.Warehouse) error DownloadIdentityRules(ctx context.Context, gzWriter *misc.GZipWriter) error - GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) Connect(ctx context.Context, warehouse model.Warehouse) (client.Client, error) LoadTestTable(ctx context.Context, location, stagingTableName string, payloadMap map[string]interface{}, loadFileFormat string) error SetConnectionTimeout(timeout time.Duration) diff --git a/warehouse/integrations/mssql/mssql.go b/warehouse/integrations/mssql/mssql.go index d8de03e27e..9dedb8a583 100644 --- a/warehouse/integrations/mssql/mssql.go +++ b/warehouse/integrations/mssql/mssql.go @@ -18,6 +18,10 @@ import ( "unicode/utf16" "unicode/utf8" + "github.com/samber/lo" + + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" + "github.com/rudderlabs/rudder-go-kit/stats" sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/logfield" @@ -45,9 +49,9 @@ const ( ) const ( - mssqlStringLengthLimit = 512 - provider = warehouseutils.MSSQL - tableNameLimit = 127 + stringLengthLimit = 512 + provider = warehouseutils.MSSQL + tableNameLimit = 127 ) var rudderDataTypesMapToMssql = map[string]string{ @@ -181,7 +185,14 @@ func (ms *MSSQL) connect() (*sqlmw.DB, error) { db, sqlmw.WithStats(ms.stats), sqlmw.WithLogger(ms.logger), - sqlmw.WithKeyAndValues(ms.defaultLogFields()), + sqlmw.WithKeyAndValues([]any{ + logfield.SourceID, ms.Warehouse.Source.ID, + logfield.SourceType, ms.Warehouse.Source.SourceDefinition.Name, + logfield.DestinationID, ms.Warehouse.Destination.ID, + logfield.DestinationType, ms.Warehouse.Destination.DestinationDefinition.Name, + logfield.WorkspaceID, ms.Warehouse.WorkspaceID, + logfield.Namespace, ms.Namespace, + }), sqlmw.WithQueryTimeout(ms.connectTimeout), sqlmw.WithSlowQueryThreshold(ms.config.slowQueryThreshold), ) @@ -200,23 +211,11 @@ func (ms *MSSQL) connectionCredentials() *credentials { } } -func (ms *MSSQL) defaultLogFields() []any { - return []any{ - logfield.SourceID, ms.Warehouse.Source.ID, - logfield.SourceType, ms.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, ms.Warehouse.Destination.ID, - logfield.DestinationType, ms.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, ms.Warehouse.WorkspaceID, - logfield.Namespace, ms.Namespace, - } -} - func ColumnsWithDataTypes(columns model.TableSchema, prefix string) string { - var arr []string - for name, dataType := range columns { - arr = append(arr, fmt.Sprintf(`"%s%s" %s`, prefix, name, rudderDataTypesMapToMssql[dataType])) - } - return strings.Join(arr, ",") + formattedColumns := lo.MapToSlice(columns, func(name, dataType string) string { + return fmt.Sprintf(`"%s%s" %s`, prefix, name, rudderDataTypesMapToMssql[dataType]) + }) + return strings.Join(formattedColumns, ",") } func (*MSSQL) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { @@ -255,227 +254,344 @@ func (ms *MSSQL) DeleteBy(ctx context.Context, tableNames []string, params wareh return nil } -func (ms *MSSQL) loadTable(ctx context.Context, tableName string, tableSchemaInUpload model.TableSchema, skipTempTableDelete bool) (stagingTableName string, err error) { - ms.logger.Infof("MSSQL: Starting load for table:%s", tableName) - - // sort column 
names - sortedColumnKeys := warehouseutils.SortColumnKeysFromColumnMap(tableSchemaInUpload) +func (ms *MSSQL) loadTable( + ctx context.Context, + tableName string, + tableSchemaInUpload model.TableSchema, + skipTempTableDelete bool, +) (*types.LoadTableStats, string, error) { + log := ms.logger.With( + logfield.SourceID, ms.Warehouse.Source.ID, + logfield.SourceType, ms.Warehouse.Source.SourceDefinition.Name, + logfield.DestinationID, ms.Warehouse.Destination.ID, + logfield.DestinationType, ms.Warehouse.Destination.DestinationDefinition.Name, + logfield.WorkspaceID, ms.Warehouse.WorkspaceID, + logfield.Namespace, ms.Namespace, + logfield.TableName, tableName, + ) + log.Infow("started loading") fileNames, err := ms.LoadFileDownLoader.Download(ctx, tableName) - defer misc.RemoveFilePaths(fileNames...) if err != nil { - return + return nil, "", fmt.Errorf("downloading load files: %w", err) } + defer func() { + misc.RemoveFilePaths(fileNames...) + }() - txn, err := ms.DB.BeginTx(ctx, &sql.TxOptions{}) - if err != nil { - ms.logger.Errorf("MSSQL: Error while beginning a transaction in db for loading in table:%s: %v", tableName, err) - return - } - // create temporary table - stagingTableName = warehouseutils.StagingTableName(provider, tableName, tableNameLimit) - // prepared stmts cannot be used to create temp objects here. Will work in a txn, but will be purged after commit. - // https://github.com/denisenkom/go-mssqldb/issues/149, https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175528(v=sql.105)?redirectedfrom=MSDN - // sqlStatement := fmt.Sprintf(`CREATE TABLE ##%[2]s like %[1]s.%[3]s`, ms.Namespace, stagingTableName, tableName) - // Hence falling back to creating normal tables - sqlStatement := fmt.Sprintf(`select top 0 * into %[1]s.%[2]s from %[1]s.%[3]s`, ms.Namespace, stagingTableName, tableName) + stagingTableName := warehouseutils.StagingTableName( + provider, + tableName, + tableNameLimit, + ) - ms.logger.Debugf("MSSQL: Creating temporary table for table:%s at %s\n", tableName, sqlStatement) - _, err = txn.ExecContext(ctx, sqlStatement) - if err != nil { - ms.logger.Errorf("MSSQL: Error creating temporary table for table:%s: %v\n", tableName, err) - _ = txn.Rollback() - return + // The use of prepared statements for creating temporary tables is not suitable in this context. + // Temporary tables in SQL Server have a limited scope and are automatically purged after the transaction commits. + // Therefore, creating normal tables is chosen as an alternative. + // + // For more information on this behavior: + // - See the discussion at https://github.com/denisenkom/go-mssqldb/issues/149 regarding prepared statements. + // - Refer to Microsoft's documentation on temporary tables at + // https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175528(v=sql.105)?redirectedfrom=MSDN. 
+ log.Debugw("creating staging table") + createStagingTableStmt := fmt.Sprintf(` + SELECT + TOP 0 * INTO %[1]s.%[2]s + FROM + %[1]s.%[3]s;`, + ms.Namespace, + stagingTableName, + tableName, + ) + if _, err = ms.DB.ExecContext(ctx, createStagingTableStmt); err != nil { + return nil, "", fmt.Errorf("creating temporary table: %w", err) } + if !skipTempTableDelete { - defer ms.dropStagingTable(ctx, stagingTableName) + defer func() { + ms.dropStagingTable(ctx, stagingTableName) + }() } - stmt, err := txn.PrepareContext(ctx, mssql.CopyIn(ms.Namespace+"."+stagingTableName, mssql.BulkOptions{CheckConstraints: false}, sortedColumnKeys...)) + txn, err := ms.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { - ms.logger.Errorf("MSSQL: Error while preparing statement for transaction in db for loading in staging table:%s: %v\nstmt: %v", stagingTableName, err, stmt) - _ = txn.Rollback() - return + return nil, "", fmt.Errorf("begin transaction: %w", err) } - for _, objectFileName := range fileNames { - var gzipFile *os.File - gzipFile, err = os.Open(objectFileName) + defer func() { if err != nil { - ms.logger.Errorf("MSSQL: Error opening file using os.Open for file:%s while loading to table %s", objectFileName, tableName) _ = txn.Rollback() - return } + }() - var gzipReader *gzip.Reader - gzipReader, err = gzip.NewReader(gzipFile) - if err != nil { - ms.logger.Errorf("MSSQL: Error reading file using gzip.NewReader for file:%s while loading to table %s", gzipFile, tableName) - gzipFile.Close() - _ = txn.Rollback() - return + sortedColumnKeys := warehouseutils.SortColumnKeysFromColumnMap( + tableSchemaInUpload, + ) + log.Debugw("creating prepared stmt for loading data") + copyInStmt := mssql.CopyIn(ms.Namespace+"."+stagingTableName, mssql.BulkOptions{CheckConstraints: false}, + sortedColumnKeys..., + ) + stmt, err := txn.PrepareContext(ctx, copyInStmt) + if err != nil { + return nil, "", fmt.Errorf("preparing copyIn statement: %w", err) + } + + log.Infow("loading data into staging table") + for _, fileName := range fileNames { + err = ms.loadDataIntoStagingTable( + ctx, log, stmt, + fileName, sortedColumnKeys, + tableSchemaInUpload, + ) + if err != nil { + return nil, "", fmt.Errorf("loading data into staging table: %w", err) } - csvReader := csv.NewReader(gzipReader) - var csvRowsProcessedCount int - for { - var record []string - record, err = csvReader.Read() - if err != nil { - if err == io.EOF { - ms.logger.Debugf("MSSQL: File reading completed while reading csv file for loading in staging table:%s: %s", stagingTableName, objectFileName) - break - } - ms.logger.Errorf("MSSQL: Error while reading csv file %s for loading in staging table:%s: %v", objectFileName, stagingTableName, err) - _ = txn.Rollback() - return - } - if len(sortedColumnKeys) != len(record) { - err = fmt.Errorf(`load file CSV columns for a row mismatch number found in upload schema. Columns in CSV row: %d, Columns in upload schema of table-%s: %d. 
Processed rows in csv file until mismatch: %d`, len(record), tableName, len(sortedColumnKeys), csvRowsProcessedCount) - ms.logger.Error(err) - _ = txn.Rollback() - return + } + if _, err = stmt.ExecContext(ctx); err != nil { + return nil, "", fmt.Errorf("executing copyIn statement: %w", err) + } + + log.Infow("deleting from load table") + rowsDeleted, err := ms.deleteFromLoadTable( + ctx, txn, tableName, + stagingTableName, + ) + if err != nil { + return nil, "", fmt.Errorf("delete from load table: %w", err) + } + + log.Infow("inserting into load table") + rowsInserted, err := ms.insertIntoLoadTable( + ctx, txn, tableName, + stagingTableName, sortedColumnKeys, + ) + if err != nil { + return nil, "", fmt.Errorf("insert into: %w", err) + } + + log.Debugw("committing transaction") + if err = txn.Commit(); err != nil { + return nil, "", fmt.Errorf("commit transaction: %w", err) + } + + log.Infow("completed loading") + + return &types.LoadTableStats{ + RowsInserted: rowsInserted - rowsDeleted, + RowsUpdated: rowsDeleted, + }, stagingTableName, nil +} + +func (ms *MSSQL) loadDataIntoStagingTable( + ctx context.Context, + log logger.Logger, + stmt *sql.Stmt, + fileName string, + sortedColumnKeys []string, + tableSchemaInUpload model.TableSchema, +) error { + gzipFile, err := os.Open(fileName) + if err != nil { + return fmt.Errorf("opening file: %w", err) + } + defer func() { + _ = gzipFile.Close() + }() + + gzipReader, err := gzip.NewReader(gzipFile) + if err != nil { + return fmt.Errorf("reading file: %w", err) + } + defer func() { + _ = gzipReader.Close() + }() + + csvReader := csv.NewReader(gzipReader) + + for { + var record []string + record, err = csvReader.Read() + if err != nil { + if errors.Is(err, io.EOF) { + break } - var recordInterface []interface{} - for _, value := range record { - if strings.TrimSpace(value) == "" { - recordInterface = append(recordInterface, nil) - } else { - recordInterface = append(recordInterface, value) - } + return fmt.Errorf("reading file: %w", err) + } + if len(sortedColumnKeys) != len(record) { + return fmt.Errorf("mismatch in number of columns: actual count: %d, expected count: %d", + len(record), + len(sortedColumnKeys), + ) + } + + recordInterface := make([]interface{}, 0, len(record)) + for _, value := range record { + if strings.TrimSpace(value) == "" { + recordInterface = append(recordInterface, nil) + } else { + recordInterface = append(recordInterface, value) } - var finalColumnValues []interface{} - for index, value := range recordInterface { - valueType := tableSchemaInUpload[sortedColumnKeys[index]] - if value == nil { - ms.logger.Debugf("MS : Found nil value for type : %s, column : %s", valueType, sortedColumnKeys[index]) - finalColumnValues = append(finalColumnValues, nil) - continue - } - strValue := value.(string) - switch valueType { - case "int": - var convertedValue int - if convertedValue, err = strconv.Atoi(strValue); err != nil { - ms.logger.Errorf("MS : Mismatch in datatype for type : %s, column : %s, value : %s, err : %v", valueType, sortedColumnKeys[index], strValue, err) - finalColumnValues = append(finalColumnValues, nil) - } else { - finalColumnValues = append(finalColumnValues, convertedValue) - } - case "float": - var convertedValue float64 - if convertedValue, err = strconv.ParseFloat(strValue, 64); err != nil { - ms.logger.Errorf("MS : Mismatch in datatype for type : %s, column : %s, value : %s, err : %v", valueType, sortedColumnKeys[index], strValue, err) - finalColumnValues = append(finalColumnValues, nil) - } else { - 
finalColumnValues = append(finalColumnValues, convertedValue) - } - case "datetime": - var convertedValue time.Time - // TODO : handling milli? - if convertedValue, err = time.Parse(time.RFC3339, strValue); err != nil { - ms.logger.Errorf("MS : Mismatch in datatype for type : %s, column : %s, value : %s, err : %v", valueType, sortedColumnKeys[index], strValue, err) - finalColumnValues = append(finalColumnValues, nil) - } else { - finalColumnValues = append(finalColumnValues, convertedValue) - } - // TODO : handling all cases? - case "boolean": - var convertedValue bool - if convertedValue, err = strconv.ParseBool(strValue); err != nil { - ms.logger.Errorf("MS : Mismatch in datatype for type : %s, column : %s, value : %s, err : %v", valueType, sortedColumnKeys[index], strValue, err) - finalColumnValues = append(finalColumnValues, nil) - } else { - finalColumnValues = append(finalColumnValues, convertedValue) - } - case "string": - // This is needed to enable diacritic support Ex: Ü,ç Ç,©,∆,ß,á,ù,ñ,ê - // A substitute to this PR; https://github.com/denisenkom/go-mssqldb/pull/576/files - // An alternate to this approach is to use nvarchar(instead of varchar) - if len(strValue) > mssqlStringLengthLimit { - strValue = strValue[:mssqlStringLengthLimit] - } - var byteArr []byte - if hasDiacritics(strValue) { - ms.logger.Debug("diacritics " + strValue) - byteArr = str2ucs2(strValue) - // This is needed as with above operation every character occupies 2 bytes - if len(byteArr) > mssqlStringLengthLimit { - byteArr = byteArr[:mssqlStringLengthLimit] - } - finalColumnValues = append(finalColumnValues, byteArr) - } else { - ms.logger.Debug("non-diacritic : " + strValue) - finalColumnValues = append(finalColumnValues, strValue) - } - default: - finalColumnValues = append(finalColumnValues, value) - } + } + + finalColumnValues := make([]interface{}, 0, len(record)) + for index, value := range recordInterface { + valueType := tableSchemaInUpload[sortedColumnKeys[index]] + if value == nil { + log.Warnw("found nil value", + logfield.ColumnType, valueType, + logfield.ColumnName, sortedColumnKeys[index], + ) + + finalColumnValues = append(finalColumnValues, nil) + continue } - _, err = stmt.ExecContext(ctx, finalColumnValues...) + processedVal, err := ms.ProcessColumnValue( + value.(string), + valueType, + ) if err != nil { - ms.logger.Errorf("MSSQL: Error in exec statement for loading in staging table:%s: %v", stagingTableName, err) - _ = txn.Rollback() - return + log.Warnw("mismatch in datatype", + logfield.ColumnType, valueType, + logfield.ColumnName, sortedColumnKeys[index], + logfield.ColumnValue, value, + logfield.Error, err, + ) + finalColumnValues = append(finalColumnValues, nil) + } else { + finalColumnValues = append(finalColumnValues, processedVal) } - csvRowsProcessedCount++ } - _ = gzipReader.Close() - gzipFile.Close() - } - _, err = stmt.ExecContext(ctx) - if err != nil { - _ = txn.Rollback() - ms.logger.Errorf("MSSQL: Rollback transaction as there was error while loading staging table:%s: %v", stagingTableName, err) - return + _, err = stmt.ExecContext(ctx, finalColumnValues...) 
+ if err != nil { + return fmt.Errorf("exec statement error: %w", err) + } + } + return nil +} +func (ms *MSSQL) ProcessColumnValue( + value string, + valueType string, +) (interface{}, error) { + switch valueType { + case model.IntDataType: + return strconv.Atoi(value) + case model.FloatDataType: + return strconv.ParseFloat(value, 64) + case model.DateTimeDataType: + return time.Parse(time.RFC3339, value) + case model.BooleanDataType: + return strconv.ParseBool(value) + case model.StringDataType: + if len(value) > stringLengthLimit { + value = value[:stringLengthLimit] + } + if !hasDiacritics(value) { + return value, nil + } else { + byteArr := str2ucs2(value) + if len(byteArr) > stringLengthLimit { + byteArr = byteArr[:stringLengthLimit] + } + return byteArr, nil + } + default: + return value, nil } - // deduplication process +} + +func (ms *MSSQL) deleteFromLoadTable( + ctx context.Context, + txn *sqlmw.Tx, + tableName string, + stagingTableName string, +) (int64, error) { primaryKey := "id" if column, ok := primaryKeyMap[tableName]; ok { primaryKey = column } - partitionKey := "id" - if column, ok := partitionKeyMap[tableName]; ok { - partitionKey = column - } - var additionalJoinClause string + + var additionalDeleteStmtClause string if tableName == warehouseutils.DiscardsTable { - additionalJoinClause = fmt.Sprintf(`AND _source.%[3]s = "%[1]s"."%[2]s"."%[3]s" AND _source.%[4]s = "%[1]s"."%[2]s"."%[4]s"`, ms.Namespace, tableName, "table_name", "column_name") - } - sqlStatement = fmt.Sprintf(`DELETE FROM "%[1]s"."%[2]s" FROM "%[1]s"."%[3]s" as _source where (_source.%[4]s = "%[1]s"."%[2]s"."%[4]s" %[5]s)`, ms.Namespace, tableName, stagingTableName, primaryKey, additionalJoinClause) - ms.logger.Infof("MSSQL: Deduplicate records for table:%s using staging table: %s\n", tableName, sqlStatement) - _, err = txn.ExecContext(ctx, sqlStatement) - if err != nil { - ms.logger.Errorf("MSSQL: Error deleting from original table for dedup: %v\n", err) - _ = txn.Rollback() - return + additionalDeleteStmtClause = fmt.Sprintf(`AND _source.%[3]s = %[1]q.%[2]q.%[3]q AND _source.%[4]s = %[1]q.%[2]q.%[4]q`, + ms.Namespace, + tableName, + "table_name", + "column_name", + ) } - quotedColumnNames := warehouseutils.DoubleQuoteAndJoinByComma(sortedColumnKeys) - sqlStatement = fmt.Sprintf(`INSERT INTO "%[1]s"."%[2]s" (%[3]s) - SELECT %[3]s FROM ( - SELECT *, row_number() OVER (PARTITION BY %[5]s ORDER BY received_at DESC) AS _rudder_staging_row_number FROM "%[1]s"."%[4]s" - ) AS _ where _rudder_staging_row_number = 1 - `, ms.Namespace, tableName, quotedColumnNames, stagingTableName, partitionKey) - ms.logger.Infof("MSSQL: Inserting records for table:%s using staging table: %s\n", tableName, sqlStatement) - _, err = txn.ExecContext(ctx, sqlStatement) + deleteStmt := fmt.Sprintf(` + DELETE FROM + %[1]q.%[2]q + FROM + %[1]q.%[3]q AS _source + WHERE + ( + _source.%[4]s = %[1]q.%[2]q.%[4]q %[5]s + );`, + ms.Namespace, + tableName, + stagingTableName, + primaryKey, + additionalDeleteStmtClause, + ) + r, err := txn.ExecContext(ctx, deleteStmt) if err != nil { - ms.logger.Errorf("MSSQL: Error inserting into original table: %v\n", err) - _ = txn.Rollback() - return + return 0, fmt.Errorf("deleting from main table: %w", err) } + return r.RowsAffected() +} - if err = txn.Commit(); err != nil { - ms.logger.Errorf("MSSQL: Error while committing transaction as there was error while loading staging table:%s: %v", stagingTableName, err) - _ = txn.Rollback() - return +func (ms *MSSQL) insertIntoLoadTable( + ctx context.Context, 
+ txn *sqlmw.Tx, + tableName string, + stagingTableName string, + sortedColumnKeys []string, +) (int64, error) { + partitionKey := "id" + if column, ok := partitionKeyMap[tableName]; ok { + partitionKey = column } - ms.logger.Infof("MSSQL: Complete load for table:%s", tableName) - return + quotedColumnNames := warehouseutils.DoubleQuoteAndJoinByComma( + sortedColumnKeys, + ) + + insertStmt := fmt.Sprintf(` + INSERT INTO %[1]q.%[2]q (%[3]s) + SELECT + %[3]s + FROM + ( + SELECT + *, + ROW_NUMBER() OVER ( + PARTITION BY %[5]s + ORDER BY + received_at DESC + ) AS _rudder_staging_row_number + FROM + %[1]q.%[4]q + ) AS _ + WHERE + _rudder_staging_row_number = 1;`, + ms.Namespace, + tableName, + quotedColumnNames, + stagingTableName, + partitionKey, + ) + + r, err := txn.ExecContext(ctx, insertStmt) + if err != nil { + return 0, fmt.Errorf("inserting into main table: %w", err) + } + return r.RowsAffected() } // Taken from https://github.com/denisenkom/go-mssqldb/blob/master/tds.go @@ -501,7 +617,7 @@ func hasDiacritics(str string) bool { func (ms *MSSQL) loadUserTables(ctx context.Context) (errorMap map[string]error) { errorMap = map[string]error{warehouseutils.IdentifiesTable: nil} ms.logger.Infof("MSSQL: Starting load for identifies and users tables\n") - identifyStagingTable, err := ms.loadTable(ctx, warehouseutils.IdentifiesTable, ms.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), true) + _, identifyStagingTable, err := ms.loadTable(ctx, warehouseutils.IdentifiesTable, ms.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), true) if err != nil { errorMap[warehouseutils.IdentifiesTable] = err return @@ -590,7 +706,7 @@ func (ms *MSSQL) loadUserTables(ctx context.Context) (errorMap map[string]error) ms.logger.Infof("MSSQL: Dedup records for table:%s using staging table: %s\n", warehouseutils.UsersTable, sqlStatement) _, err = tx.ExecContext(ctx, sqlStatement) if err != nil { - ms.logger.Errorf("MSSQL: Error deleting from original table for dedup: %v\n", err) + ms.logger.Errorf("MSSQL: Error deleting from main table for dedup: %v\n", err) _ = tx.Rollback() errorMap[warehouseutils.UsersTable] = err return @@ -619,11 +735,10 @@ func (ms *MSSQL) loadUserTables(ctx context.Context) (errorMap map[string]error) func (ms *MSSQL) CreateSchema(ctx context.Context) (err error) { sqlStatement := fmt.Sprintf(`IF NOT EXISTS ( SELECT * FROM sys.schemas WHERE name = N'%s' ) - EXEC('CREATE SCHEMA [%s]'); -`, ms.Namespace, ms.Namespace) + EXEC('CREATE SCHEMA [%s]');`, ms.Namespace, ms.Namespace) ms.logger.Infof("MSSQL: Creating schema name in mssql for MSSQL:%s : %v", ms.Warehouse.Destination.ID, sqlStatement) _, err = ms.DB.ExecContext(ctx, sqlStatement) - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } return @@ -675,8 +790,7 @@ func (ms *MSSQL) AddColumns(ctx context.Context, tableName string, columnsInfo [ WHERE OBJECT_ID = OBJECT_ID(N'%[1]s.%[2]s') AND name = '%[3]s' - ) -`, + )`, ms.Namespace, tableName, columnsInfo[0].Name, @@ -794,8 +908,7 @@ func (ms *MSSQL) FetchSchema(ctx context.Context) (model.Schema, model.Schema, e INFORMATION_SCHEMA.COLUMNS WHERE table_schema = @schema - and table_name not like @prefix -` + and table_name not like @prefix` rows, err := ms.DB.QueryContext(ctx, sqlStatement, sql.Named("schema", ms.Namespace), sql.Named("prefix", fmt.Sprintf("%s%%", warehouseutils.StagingTablePrefix(provider))), @@ -840,9 +953,14 @@ func (ms *MSSQL) LoadUserTables(ctx context.Context) map[string]error { return ms.loadUserTables(ctx) } -func (ms *MSSQL) 
LoadTable(ctx context.Context, tableName string) error { - _, err := ms.loadTable(ctx, tableName, ms.Uploader.GetTableSchemaInUpload(tableName), false) - return err +func (ms *MSSQL) LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) { + loadTableStat, _, err := ms.loadTable( + ctx, + tableName, + ms.Uploader.GetTableSchemaInUpload(tableName), + false, + ) + return loadTableStat, err } func (ms *MSSQL) Cleanup(ctx context.Context) { @@ -865,22 +983,6 @@ func (*MSSQL) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err erro return } -func (ms *MSSQL) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) { - var ( - total int64 - err error - sqlStatement string - ) - sqlStatement = fmt.Sprintf(` - SELECT count(*) FROM "%[1]s"."%[2]s"; - `, - ms.Namespace, - tableName, - ) - err = ms.DB.QueryRowContext(ctx, sqlStatement).Scan(&total) - return total, err -} - func (ms *MSSQL) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) { ms.Warehouse = warehouse ms.Namespace = warehouse.Namespace diff --git a/warehouse/integrations/mssql/mssql_test.go b/warehouse/integrations/mssql/mssql_test.go index d074f64241..7d4bd15953 100644 --- a/warehouse/integrations/mssql/mssql_test.go +++ b/warehouse/integrations/mssql/mssql_test.go @@ -6,9 +6,20 @@ import ( "fmt" "os" "strconv" + "strings" "testing" "time" + "github.com/golang/mock/gomock" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-server/warehouse/integrations/mssql" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" @@ -68,6 +79,7 @@ func TestIntegration(t *testing.T) { bucketName := "testbucket" accessKeyID := "MYACCESSKEY" secretAccessKey := "MYSECRETKEY" + region := "us-east-1" minioEndpoint := fmt.Sprintf("localhost:%d", minioPort) @@ -100,11 +112,8 @@ func TestIntegration(t *testing.T) { t.Setenv("MINIO_SECRET_ACCESS_KEY", secretAccessKey) t.Setenv("MINIO_MINIO_ENDPOINT", minioEndpoint) t.Setenv("MINIO_SSL", "false") - t.Setenv("RSERVER_WAREHOUSE_MSSQL_MAX_PARALLEL_LOADS", "8") - t.Setenv("RSERVER_WAREHOUSE_MSSQL_ENABLE_DELETE_BY_JOBS", "true") t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) - t.Setenv("RSERVER_WAREHOUSE_MSSQL_SLOW_QUERY_THRESHOLD", "0s") svcDone := make(chan struct{}) @@ -123,6 +132,10 @@ func TestIntegration(t *testing.T) { health.WaitUntilReady(ctx, t, serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint") t.Run("Events flow", func(t *testing.T) { + t.Setenv("RSERVER_WAREHOUSE_MSSQL_SLOW_QUERY_THRESHOLD", "0s") + t.Setenv("RSERVER_WAREHOUSE_MSSQL_MAX_PARALLEL_LOADS", "8") + t.Setenv("RSERVER_WAREHOUSE_MSSQL_ENABLE_DELETE_BY_JOBS", "true") + jobsDB := testhelper.JobsDB(t, jobsDBPort) dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%d?TrustServerCertificate=true&database=%s&encrypt=disable", @@ -282,4 +295,455 @@ func TestIntegration(t *testing.T) { } testhelper.VerifyConfigurationTest(t, dest) }) + + t.Run("Load Table", func(t *testing.T) { + const ( + sourceID = "test_source_id" + destinationID = "test_destination_id" + workspaceID = "test_workspace_id" + ) + + 
namespace := testhelper.RandSchema(destType) + + schemaInUpload := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + } + schemaInWarehouse := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + "extra_test_bool": "boolean", + "extra_test_datetime": "datetime", + "extra_test_float": "float", + "extra_test_int": "int", + "extra_test_string": "string", + } + + warehouse := model.Warehouse{ + Source: backendconfig.SourceT{ + ID: sourceID, + }, + Destination: backendconfig.DestinationT{ + ID: destinationID, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: destType, + }, + Config: map[string]any{ + "host": host, + "database": database, + "user": user, + "password": password, + "port": strconv.Itoa(mssqlPort), + "sslMode": "disable", + "namespace": "", + "bucketProvider": "MINIO", + "bucketName": bucketName, + "accessKeyID": accessKeyID, + "secretAccessKey": secretAccessKey, + "useSSL": false, + "endPoint": minioEndpoint, + "syncFrequency": "30", + "useRudderStorage": false, + }, + }, + WorkspaceID: workspaceID, + Namespace: namespace, + } + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: warehouseutils.MINIO, + Config: map[string]any{ + "bucketName": bucketName, + "accessKeyID": accessKeyID, + "secretAccessKey": secretAccessKey, + "endPoint": minioEndpoint, + "forcePathStyle": true, + "s3ForcePathStyle": true, + "disableSSL": true, + "region": region, + "enableSSE": false, + "bucketProvider": warehouseutils.MINIO, + }, + }) + require.NoError(t, err) + + t.Run("schema does not exists", func(t *testing.T) { + tableName := "schema_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + ms := mssql.New(config.Default, logger.NOP, stats.Default) + err := ms.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + loadTableStat, err := ms.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("table does not exists", func(t *testing.T) { + tableName := "table_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + ms := mssql.New(config.Default, logger.NOP, stats.Default) + err := ms.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = ms.CreateSchema(ctx) + require.NoError(t, err) + + loadTableStat, err := ms.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("merge", func(t *testing.T) { + tableName := "merge_test_table" + + t.Run("without dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + ms := mssql.New(config.Default, logger.NOP, stats.Default) + err := 
ms.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = ms.CreateSchema(ctx) + require.NoError(t, err) + + err = ms.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := ms.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = ms.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, ms.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + cast(test_float AS float) AS test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("with dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + ms := mssql.New(config.Default, logger.NOP, stats.Default) + err := ms.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = ms.CreateSchema(ctx) + require.NoError(t, err) + + err = ms.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := ms.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, ms.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + cast(test_float AS float) AS test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DedupTestRecords()) + }) + }) + t.Run("load file does not exists", func(t *testing.T) { + tableName := "load_file_not_exists_test_table" + + loadFiles := []warehouseutils.LoadFile{{ + Location: "http://localhost:1234/testbucket/rudder-warehouse-load-objects/load_file_not_exists_test_table/test_source_id/f31af97e-03e8-46d0-8a1a-1786cb85b22c-load_file_not_exists_test_table/load.csv.gz", + }} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + ms := mssql.New(config.Default, logger.NOP, stats.Default) + err := ms.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = ms.CreateSchema(ctx) + require.NoError(t, err) + + err = ms.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := ms.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in number of columns", func(t *testing.T) { + tableName := "mismatch_columns_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + ms := mssql.New(config.Default, logger.NOP, stats.Default) + err := ms.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = ms.CreateSchema(ctx) + require.NoError(t, err) + + err = ms.CreateTable(ctx, tableName, 
schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := ms.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in schema", func(t *testing.T) { + tableName := "mismatch_schema_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + ms := mssql.New(config.Default, logger.NOP, stats.Default) + err := ms.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = ms.CreateSchema(ctx) + require.NoError(t, err) + + err = ms.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := ms.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, ms.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + cast(test_float AS float) AS test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.MismatchSchemaTestRecords()) + }) + t.Run("discards", func(t *testing.T) { + tableName := warehouseutils.DiscardsTable + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) + + ms := mssql.New(config.Default, logger.NOP, stats.Default) + err := ms.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = ms.CreateSchema(ctx) + require.NoError(t, err) + + err = ms.CreateTable(ctx, tableName, warehouseutils.DiscardsSchema) + require.NoError(t, err) + + loadTableStat, err := ms.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(6)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, ms.DB.DB, + fmt.Sprintf(` + SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM + %q.%q + ORDER BY row_id ASC; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DiscardTestRecords()) + }) + }) +} + +func TestMSSQL_ProcessColumnValue(t *testing.T) { + testCases := []struct { + name string + data string + dataType string + expectedValue interface{} + wantError bool + }{ + { + name: "invalid integer", + data: "1.01", + dataType: model.IntDataType, + wantError: true, + }, + { + name: "valid integer", + data: "1", + dataType: model.IntDataType, + expectedValue: int64(1), + }, + { + name: "invalid float", + data: "test", + dataType: model.FloatDataType, + wantError: true, + }, + { + name: "valid float", + data: "1.01", + dataType: model.FloatDataType, + expectedValue: float64(1.01), + }, + { + name: "invalid boolean", + data: "test", + dataType: model.BooleanDataType, + wantError: true, + }, + { + name: "valid boolean", + data: "true", + dataType: model.BooleanDataType, + expectedValue: true, + }, + { + name: "invalid datetime", + data: "1", + dataType: model.DateTimeDataType, + wantError: true, + }, + { + name: "valid datetime", + data: "2020-01-01T00:00:00Z", + dataType: 
model.DateTimeDataType, + expectedValue: time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "valid string", + data: "test", + dataType: model.StringDataType, + expectedValue: "test", + }, + { + name: "valid string exceeding max length", + data: strings.Repeat("test", 200), + dataType: model.StringDataType, + expectedValue: strings.Repeat("test", 128), + }, + { + name: "valid string with diacritics", + data: "tést", + dataType: model.StringDataType, + expectedValue: []byte{0x74, 0x0, 0xe9, 0x0, 0x73, 0x0, 0x74, 0x0}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ms := mssql.New(config.Default, logger.NOP, stats.Default) + + value, err := ms.ProcessColumnValue(tc.data, tc.dataType) + if tc.wantError { + require.Error(t, err) + return + } + require.EqualValues(t, tc.expectedValue, value) + require.NoError(t, err) + }) + } +} + +func newMockUploader( + t testing.TB, + loadFiles []warehouseutils.LoadFile, + tableName string, + schemaInUpload model.TableSchema, + schemaInWarehouse model.TableSchema, +) warehouseutils.Uploader { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockUploader := mockuploader.NewMockUploader(ctrl) + mockUploader.EXPECT().UseRudderStorage().Return(false).AnyTimes() + mockUploader.EXPECT().GetLoadFilesMetadata(gomock.Any(), gomock.Any()).Return(loadFiles).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() + + return mockUploader } diff --git a/warehouse/integrations/postgres/load.go b/warehouse/integrations/postgres/load.go index 48c1f1d5f1..4f9471a85c 100644 --- a/warehouse/integrations/postgres/load.go +++ b/warehouse/integrations/postgres/load.go @@ -3,6 +3,7 @@ package postgres import ( "compress/gzip" "context" + "database/sql" "encoding/csv" "errors" "fmt" @@ -10,6 +11,8 @@ import ( "os" "strings" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" + sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -18,22 +21,17 @@ import ( "github.com/lib/pq" "golang.org/x/exp/slices" - "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/warehouse/logfield" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) -type loadTableResponse struct { - StagingTableName string -} - type loadUsersTableResponse struct { identifiesError error usersError error } -func (pg *Postgres) LoadTable(ctx context.Context, tableName string) error { - pg.logger.Infow("started loading", +func (pg *Postgres) LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) { + log := pg.logger.With( logfield.SourceID, pg.Warehouse.Source.ID, logfield.SourceType, pg.Warehouse.Source.SourceDefinition.Name, logfield.DestinationID, pg.Warehouse.Destination.ID, @@ -41,19 +39,38 @@ func (pg *Postgres) LoadTable(ctx context.Context, tableName string) error { logfield.WorkspaceID, pg.Warehouse.WorkspaceID, logfield.Namespace, pg.Namespace, logfield.TableName, tableName, + logfield.LoadTableStrategy, pg.loadTableStrategy(), ) + log.Infow("started loading") - err := pg.DB.WithTx(ctx, func(tx *sqlmiddleware.Tx) error { - tableSchemaInUpload := pg.Uploader.GetTableSchemaInUpload(tableName) + var loadTableStats *types.LoadTableStats + var err error - _, err := pg.loadTable(ctx, tx, tableName, 
tableSchemaInUpload) + err = pg.DB.WithTx(ctx, func(tx *sqlmiddleware.Tx) error { + loadTableStats, _, err = pg.loadTable( + ctx, + tx, + tableName, + pg.Uploader.GetTableSchemaInUpload(tableName), + ) return err }) if err != nil { - return fmt.Errorf("loading table: %w", err) + return nil, fmt.Errorf("loading table: %w", err) } - pg.logger.Infow("completed loading", + log.Infow("completed loading") + + return loadTableStats, err +} + +func (pg *Postgres) loadTable( + ctx context.Context, + txn *sqlmiddleware.Tx, + tableName string, + tableSchemaInUpload model.TableSchema, +) (*types.LoadTableStats, string, error) { + log := pg.logger.With( logfield.SourceID, pg.Warehouse.Source.ID, logfield.SourceType, pg.Warehouse.Source.SourceDefinition.Name, logfield.DestinationID, pg.Warehouse.Destination.ID, @@ -61,26 +78,34 @@ func (pg *Postgres) LoadTable(ctx context.Context, tableName string) error { logfield.WorkspaceID, pg.Warehouse.WorkspaceID, logfield.Namespace, pg.Namespace, logfield.TableName, tableName, + logfield.LoadTableStrategy, pg.loadTableStrategy(), ) + log.Infow("started loading") - return nil -} + log.Debugw("setting search path") + searchPathStmt := fmt.Sprintf(`SET search_path TO %q;`, + pg.Namespace, + ) + if _, err := txn.ExecContext(ctx, searchPathStmt); err != nil { + return nil, "", fmt.Errorf("setting search path: %w", err) + } -func (pg *Postgres) loadTable( - ctx context.Context, - txn *sqlmiddleware.Tx, - tableName string, - tableSchemaInUpload model.TableSchema, -) (loadTableResponse, error) { - query := fmt.Sprintf(`SET search_path TO %q;`, pg.Namespace) - if _, err := txn.ExecContext(ctx, query); err != nil { - return loadTableResponse{}, fmt.Errorf("setting search path: %w", err) + loadFiles, err := pg.LoadFileDownloader.Download(ctx, tableName) + if err != nil { + return nil, "", fmt.Errorf("downloading load files: %w", err) } + defer func() { + misc.RemoveFilePaths(loadFiles...) 
+ }() - // Creating staging table - sortedColumnKeys := warehouseutils.SortColumnKeysFromColumnMap(tableSchemaInUpload) - stagingTableName := warehouseutils.StagingTableName(provider, tableName, tableNameLimit) - query = fmt.Sprintf(` + stagingTableName := warehouseutils.StagingTableName( + provider, + tableName, + tableNameLimit, + ) + + log.Debugw("creating staging table") + createStagingTableStmt := fmt.Sprintf(` CREATE TEMPORARY TABLE %[2]s (LIKE %[1]q.%[3]q) ON COMMIT PRESERVE ROWS; `, @@ -88,102 +113,137 @@ func (pg *Postgres) loadTable( stagingTableName, tableName, ) - pg.logger.Infow("creating temporary table", - logfield.SourceID, pg.Warehouse.Source.ID, - logfield.SourceType, pg.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, pg.Warehouse.Destination.ID, - logfield.DestinationType, pg.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, pg.Warehouse.WorkspaceID, - logfield.Namespace, pg.Namespace, - logfield.TableName, tableName, - logfield.StagingTableName, stagingTableName, - logfield.Query, query, - ) - if _, err := txn.ExecContext(ctx, query); err != nil { - return loadTableResponse{}, fmt.Errorf("creating temporary table: %w", err) + if _, err := txn.ExecContext(ctx, createStagingTableStmt); err != nil { + return nil, "", fmt.Errorf("creating temporary table: %w", err) } - stmt, err := txn.PrepareContext(ctx, pq.CopyIn(stagingTableName, sortedColumnKeys...)) - if err != nil { - return loadTableResponse{}, fmt.Errorf("preparing statement for copy in: %w", err) - } + sortedColumnKeys := warehouseutils.SortColumnKeysFromColumnMap( + tableSchemaInUpload, + ) - loadFiles, err := pg.LoadFileDownloader.Download(ctx, tableName) - defer misc.RemoveFilePaths(loadFiles...) + log.Debugw("creating prepared stmt for loading data") + copyInStmt := pq.CopyIn(stagingTableName, sortedColumnKeys...) 
+ stmt, err := txn.PrepareContext(ctx, copyInStmt) if err != nil { - return loadTableResponse{}, fmt.Errorf("downloading load files: %w", err) + return nil, "", fmt.Errorf("preparing statement for copy in: %w", err) } - var csvRowsProcessedCount int64 - for _, objectFileName := range loadFiles { - gzFile, err := os.Open(objectFileName) + log.Infow("loading data into staging table") + for _, fileName := range loadFiles { + err = pg.loadDataIntoStagingTable( + ctx, stmt, + fileName, sortedColumnKeys, + ) if err != nil { - return loadTableResponse{}, fmt.Errorf("opening load file: %w", err) + return nil, "", fmt.Errorf("loading data into staging table: %w", err) } + } + if _, err = stmt.ExecContext(ctx); err != nil { + return nil, "", fmt.Errorf("executing copyIn statement: %w", err) + } - gzReader, err := gzip.NewReader(gzFile) + var rowsDeleted int64 + if !slices.Contains(pg.config.skipDedupDestinationIDs, pg.Warehouse.Destination.ID) { + log.Infow("deleting from load table") + rowsDeleted, err = pg.deleteFromLoadTable( + ctx, txn, tableName, + stagingTableName, + ) if err != nil { - _ = gzFile.Close() - - return loadTableResponse{}, fmt.Errorf("reading gzip load file: %w", err) + return nil, "", fmt.Errorf("delete from load table: %w", err) } + } - csvReader := csv.NewReader(gzReader) + log.Infow("inserting into load table") + rowsInserted, err := pg.insertIntoLoadTable( + ctx, txn, tableName, + stagingTableName, sortedColumnKeys, + ) + if err != nil { + return nil, "", fmt.Errorf("insert into: %w", err) + } - for { - var ( - record []string - recordInterface []interface{} - ) + return &types.LoadTableStats{ + RowsInserted: rowsInserted - rowsDeleted, + RowsUpdated: rowsDeleted, + }, stagingTableName, nil +} - record, err := csvReader.Read() - if err != nil { - if err == io.EOF { - break - } +func (pg *Postgres) loadDataIntoStagingTable( + ctx context.Context, + stmt *sql.Stmt, + fileName string, + sortedColumnKeys []string, +) error { + gzipFile, err := os.Open(fileName) + if err != nil { + return fmt.Errorf("opening load file: %w", err) + } + defer func() { + _ = gzipFile.Close() + }() - return loadTableResponse{}, fmt.Errorf("reading csv file: %w", err) - } + gzReader, err := gzip.NewReader(gzipFile) + if err != nil { + return fmt.Errorf("reading gzip load file: %w", err) + } + defer func() { + _ = gzReader.Close() + }() - if len(sortedColumnKeys) != len(record) { - return loadTableResponse{}, fmt.Errorf("missing columns in csv file %s", objectFileName) - } + csvReader := csv.NewReader(gzReader) - for _, value := range record { - if strings.TrimSpace(value) == "" { - recordInterface = append(recordInterface, nil) - } else { - recordInterface = append(recordInterface, value) - } + for { + record, err := csvReader.Read() + if err != nil { + if errors.Is(err, io.EOF) { + break } + return fmt.Errorf("reading file: %w", err) + } + if len(sortedColumnKeys) != len(record) { + return fmt.Errorf("mismatch in number of columns: actual count: %d, expected count: %d", + len(record), + len(sortedColumnKeys), + ) + } - _, err = stmt.ExecContext(ctx, recordInterface...) 
- if err != nil { - return loadTableResponse{}, fmt.Errorf("exec statement: %w", err) + recordInterface := make([]interface{}, 0, len(record)) + for _, value := range record { + if strings.TrimSpace(value) == "" { + recordInterface = append(recordInterface, nil) + } else { + recordInterface = append(recordInterface, value) } - - csvRowsProcessedCount++ } - _ = gzReader.Close() - _ = gzFile.Close() - } - if _, err = stmt.ExecContext(ctx); err != nil { - return loadTableResponse{}, fmt.Errorf("exec statement: %w", err) + _, err = stmt.ExecContext(ctx, recordInterface...) + if err != nil { + return fmt.Errorf("exec statement: %w", err) + } } + return nil +} - var ( - primaryKey = "id" - partitionKey = "id" +func (pg *Postgres) loadTableStrategy() string { + if slices.Contains(pg.config.skipDedupDestinationIDs, pg.Warehouse.Destination.ID) { + return "APPEND" + } + return "MERGE" +} - additionalJoinClause string - ) +func (pg *Postgres) deleteFromLoadTable( + ctx context.Context, + txn *sqlmiddleware.Tx, + tableName string, + stagingTableName string, +) (int64, error) { + primaryKey := "id" if column, ok := primaryKeyMap[tableName]; ok { primaryKey = column } - if column, ok := partitionKeyMap[tableName]; ok { - partitionKey = column - } + + var additionalJoinClause string if tableName == warehouseutils.DiscardsTable { additionalJoinClause = fmt.Sprintf( `AND _source.%[3]s = %[1]q.%[2]q.%[3]q AND _source.%[4]s = %[1]q.%[2]q.%[4]q`, @@ -194,9 +254,7 @@ func (pg *Postgres) loadTable( ) } - // Deduplication - // Delete rows from the table which are already present in the staging table - query = fmt.Sprintf(` + deleteStmt := fmt.Sprintf(` DELETE FROM %[1]q.%[2]q USING %[3]q AS _source WHERE @@ -211,44 +269,30 @@ func (pg *Postgres) loadTable( additionalJoinClause, ) - if !slices.Contains(pg.config.skipDedupDestinationIDs, pg.Warehouse.Destination.ID) { - pg.logger.Infow("deduplication", - logfield.SourceID, pg.Warehouse.Source.ID, - logfield.SourceType, pg.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, pg.Warehouse.Destination.ID, - logfield.DestinationType, pg.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, pg.Warehouse.WorkspaceID, - logfield.Namespace, pg.Namespace, - logfield.TableName, tableName, - logfield.StagingTableName, stagingTableName, - logfield.Query, query, - ) - - result, err := txn.ExecContext(ctx, query) - if err != nil { - return loadTableResponse{}, fmt.Errorf("deleting from original table for dedup: %w", err) - } - - rowsAffected, err := result.RowsAffected() - if err != nil { - return loadTableResponse{}, fmt.Errorf("getting rows affected for dedup: %w", err) - } + result, err := txn.ExecContext(ctx, deleteStmt) + if err != nil { + return 0, fmt.Errorf("deleting from main table for dedup: %w", err) + } + return result.RowsAffected() +} - pg.stats.NewTaggedStat("dedup_rows", stats.CountType, stats.Tags{ - "sourceID": pg.Warehouse.Source.ID, - "sourceType": pg.Warehouse.Source.SourceDefinition.Name, - "sourceCategory": pg.Warehouse.Source.SourceDefinition.Category, - "destID": pg.Warehouse.Destination.ID, - "destType": pg.Warehouse.Destination.DestinationDefinition.Name, - "workspaceId": pg.Warehouse.WorkspaceID, - "tableName": tableName, - "rowsAffected": fmt.Sprintf("%d", rowsAffected), - }) +func (pg *Postgres) insertIntoLoadTable( + ctx context.Context, + txn *sqlmiddleware.Tx, + tableName string, + stagingTableName string, + sortedColumnKeys []string, +) (int64, error) { + partitionKey := "id" + if column, ok := 
partitionKeyMap[tableName]; ok { + partitionKey = column } - // Insert rows from staging table to the original table - quotedColumnNames := warehouseutils.DoubleQuoteAndJoinByComma(sortedColumnKeys) - query = fmt.Sprintf(` + quotedColumnNames := warehouseutils.DoubleQuoteAndJoinByComma( + sortedColumnKeys, + ) + + insertStmt := fmt.Sprintf(` INSERT INTO %[1]q.%[2]q (%[3]s) SELECT %[3]s @@ -274,25 +318,11 @@ func (pg *Postgres) loadTable( partitionKey, ) - pg.logger.Infow("inserting records", - logfield.SourceID, pg.Warehouse.Source.ID, - logfield.SourceType, pg.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, pg.Warehouse.Destination.ID, - logfield.DestinationType, pg.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, pg.Warehouse.WorkspaceID, - logfield.Namespace, pg.Namespace, - logfield.TableName, tableName, - logfield.StagingTableName, stagingTableName, - logfield.Query, query, - ) - if _, err := txn.ExecContext(ctx, query); err != nil { - return loadTableResponse{}, fmt.Errorf("executing query: %w", err) - } - - response := loadTableResponse{ - StagingTableName: stagingTableName, + r, err := txn.ExecContext(ctx, insertStmt) + if err != nil { + return 0, fmt.Errorf("inserting into main table: %w", err) } - return response, nil + return r.RowsAffected() } func (pg *Postgres) LoadUserTables(ctx context.Context) map[string]error { @@ -357,7 +387,7 @@ func (pg *Postgres) loadUsersTable( usersSchemaInUpload, usersSchemaInWarehouse model.TableSchema, ) loadUsersTableResponse { - identifiesTableResponse, err := pg.loadTable(ctx, tx, warehouseutils.IdentifiesTable, identifiesSchemaInUpload) + _, identifyStagingTable, err := pg.loadTable(ctx, tx, warehouseutils.IdentifiesTable, identifiesSchemaInUpload) if err != nil { return loadUsersTableResponse{ identifiesError: fmt.Errorf("loading identifies table: %w", err), @@ -370,7 +400,7 @@ func (pg *Postgres) loadUsersTable( canSkipComputingLatestUserTraits := pg.config.skipComputingUserLatestTraits || slices.Contains(pg.config.skipComputingUserLatestTraitsWorkspaceIDs, pg.Warehouse.WorkspaceID) if canSkipComputingLatestUserTraits { - if _, err = pg.loadTable(ctx, tx, warehouseutils.UsersTable, usersSchemaInUpload); err != nil { + if _, _, err = pg.loadTable(ctx, tx, warehouseutils.UsersTable, usersSchemaInUpload); err != nil { return loadUsersTableResponse{ usersError: fmt.Errorf("loading users table: %w", err), } @@ -443,7 +473,7 @@ func (pg *Postgres) loadUsersTable( `, pg.Namespace, warehouseutils.UsersTable, - identifiesTableResponse.StagingTableName, + identifyStagingTable, strings.Join(userColNames, ","), unionStagingTableName, ) diff --git a/warehouse/integrations/postgres/load_test.go b/warehouse/integrations/postgres/load_test.go index a1d4246fff..71e22197e4 100644 --- a/warehouse/integrations/postgres/load_test.go +++ b/warehouse/integrations/postgres/load_test.go @@ -1,4 +1,4 @@ -package postgres +package postgres_test import ( "context" @@ -10,6 +10,8 @@ import ( "github.com/golang/mock/gomock" + "github.com/rudderlabs/rudder-server/warehouse/integrations/postgres" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -58,304 +60,6 @@ func cloneFiles(t *testing.T, files []string) []string { return tempFiles } -func TestLoadTable(t *testing.T) { - t.Parallel() - - misc.Init() - warehouseutils.Init() - - pool, err := dockertest.NewPool("") - require.NoError(t, err) - - const ( - namespace = 
"test_namespace" - sourceID = "test_source_id" - destID = "test_dest_id" - sourceType = "test_source_type" - destType = "test_dest_type" - workspaceID = "test_workspace_id" - ) - - warehouse := model.Warehouse{ - Source: backendconfig.SourceT{ - ID: sourceID, - SourceDefinition: backendconfig.SourceDefinitionT{ - Name: sourceType, - }, - }, - Destination: backendconfig.DestinationT{ - ID: destID, - DestinationDefinition: backendconfig.DestinationDefinitionT{ - Name: destType, - }, - }, - WorkspaceID: workspaceID, - } - - t.Run("Regular tables", func(t *testing.T) { - t.Parallel() - - tableName := "test_table" - - testCases := []struct { - name string - wantError error - mockError error - skipSchemaCreation bool - skipTableCreation bool - cancelContext bool - mockFiles []string - additionalFiles []string - queryExecEnabledWorkspaceIDs []string - }{ - { - name: "schema not present", - skipSchemaCreation: true, - mockFiles: []string{"load.csv.gz"}, - wantError: errors.New("loading table: executing transaction: creating temporary table: pq: schema \"test_namespace\" does not exist"), - }, - { - name: "table not present", - skipTableCreation: true, - mockFiles: []string{"load.csv.gz"}, - wantError: errors.New("loading table: executing transaction: creating temporary table: pq: relation \"test_namespace.test_table\" does not exist"), - }, - { - name: "download error", - mockFiles: []string{"load.csv.gz"}, - mockError: errors.New("test error"), - wantError: errors.New("loading table: executing transaction: downloading load files: test error"), - }, - { - name: "load file not present", - additionalFiles: []string{"testdata/random.csv.gz"}, - wantError: errors.New("loading table: executing transaction: opening load file: open testdata/random.csv.gz: no such file or directory"), - }, - { - name: "less records than expected", - mockFiles: []string{"less-records.csv.gz"}, - wantError: errors.New("loading table: executing transaction: missing columns in csv file"), - }, - { - name: "bad records", - mockFiles: []string{"bad.csv.gz"}, - wantError: errors.New("loading table: executing transaction: exec statement: pq: invalid input syntax for type timestamp: \"1\""), - }, - { - name: "success", - mockFiles: []string{"load.csv.gz"}, - }, - { - name: "enable query execution", - mockFiles: []string{"load.csv.gz"}, - queryExecEnabledWorkspaceIDs: []string{workspaceID}, - }, - { - name: "context cancelled", - mockFiles: []string{"load.csv.gz"}, - wantError: errors.New("loading table: begin transaction: context canceled"), - cancelContext: true, - }, - } - - for _, tc := range testCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - pgResource, err := resource.SetupPostgres(pool, t) - require.NoError(t, err) - - t.Log("db:", pgResource.DBDsn) - - db := sqlmiddleware.New(pgResource.DB) - - store := memstats.New() - c := config.New() - c.Set("Warehouse.postgres.EnableSQLStatementExecutionPlanWorkspaceIDs", tc.queryExecEnabledWorkspaceIDs) - - ctx, cancel := context.WithCancel(context.Background()) - if tc.cancelContext { - cancel() - } else { - defer cancel() - } - - if !tc.skipSchemaCreation { - _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS " + namespace) - require.NoError(t, err) - } - if !tc.skipTableCreation && !tc.skipSchemaCreation { - _, err = db.Exec(fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s.%s ( - test_bool boolean, - test_datetime timestamp, - test_float float, - test_int int, - test_string varchar(255), - id varchar(255), - received_at timestamptz - ) - `, - namespace, - 
tableName, - )) - require.NoError(t, err) - } - - loadFiles := cloneFiles(t, tc.mockFiles) - loadFiles = append(loadFiles, tc.additionalFiles...) - require.NotEmpty(t, loadFiles) - - pg := New(c, logger.NOP, store) - pg.DB = db - pg.Namespace = namespace - pg.Warehouse = warehouse - pg.stats = store - pg.LoadFileDownloader = &mockLoadFileUploader{ - mockFiles: map[string][]string{ - tableName: loadFiles, - }, - mockError: map[string]error{ - tableName: tc.mockError, - }, - } - pg.Uploader = newMockUploader(t, map[string]model.TableSchema{ - tableName: { - "test_bool": "boolean", - "test_datetime": "datetime", - "test_float": "float", - "test_int": "int", - "test_string": "string", - "id": "string", - "received_at": "datetime", - }, - }) - - err = pg.LoadTable(ctx, tableName) - if tc.wantError != nil { - require.ErrorContains(t, err, tc.wantError.Error()) - return - } - require.NoError(t, err) - }) - } - }) - - t.Run("Discards tables", func(t *testing.T) { - t.Parallel() - - tableName := warehouseutils.DiscardsTable - - testCases := []struct { - name string - wantError error - mockError error - skipSchemaCreation bool - skipTableCreation bool - cancelContext bool - mockFiles []string - }{ - { - name: "schema not present", - skipSchemaCreation: true, - mockFiles: []string{"discards.csv.gz"}, - wantError: errors.New("loading table: executing transaction: creating temporary table: pq: schema \"test_namespace\" does not exist"), - }, - { - name: "table not present", - skipTableCreation: true, - mockFiles: []string{"discards.csv.gz"}, - wantError: errors.New("loading table: executing transaction: creating temporary table: pq: relation \"test_namespace.rudder_discards\" does not exist"), - }, - { - name: "download error", - mockFiles: []string{"discards.csv.gz"}, - wantError: errors.New("loading table: executing transaction: downloading load files: test error"), - mockError: errors.New("test error"), - }, - { - name: "success", - mockFiles: []string{"discards.csv.gz"}, - }, - } - - for _, tc := range testCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - pgResource, err := resource.SetupPostgres(pool, t) - require.NoError(t, err) - - t.Log("db:", pgResource.DBDsn) - - db := sqlmiddleware.New(pgResource.DB) - - store := memstats.New() - c := config.New() - - ctx, cancel := context.WithCancel(context.Background()) - if tc.cancelContext { - cancel() - } else { - defer cancel() - } - - if !tc.skipSchemaCreation { - _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS " + namespace) - require.NoError(t, err) - } - if !tc.skipTableCreation && !tc.skipSchemaCreation { - _, err = db.Exec(fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s.%s ( - "column_name" "varchar", - "column_value" "varchar", - "received_at" "timestamptz", - "row_id" "varchar", - "table_name" "varchar", - "uuid_ts" "timestamptz" - ) - `, - namespace, - tableName, - )) - require.NoError(t, err) - } - - loadFiles := cloneFiles(t, tc.mockFiles) - require.NotEmpty(t, loadFiles) - - pg := New(c, logger.NOP, store) - pg.DB = db - pg.Namespace = namespace - pg.Warehouse = warehouse - pg.stats = store - pg.LoadFileDownloader = &mockLoadFileUploader{ - mockFiles: map[string][]string{ - tableName: loadFiles, - }, - mockError: map[string]error{ - tableName: tc.mockError, - }, - } - pg.Uploader = newMockUploader(t, map[string]model.TableSchema{ - tableName: warehouseutils.DiscardsSchema, - }) - - err = pg.LoadTable(ctx, tableName) - if tc.wantError != nil { - require.EqualError(t, err, tc.wantError.Error()) - return - } - 
require.NoError(t, err) - }) - } - }) -} - func TestLoadUsersTable(t *testing.T) { t.Parallel() @@ -501,7 +205,7 @@ func TestLoadUsersTable(t *testing.T) { identifiesLoadFiles := cloneFiles(t, tc.mockIdentifiesFiles) require.NotEmpty(t, identifiesLoadFiles) - pg := New(c, logger.NOP, store) + pg := postgres.New(c, logger.NOP, store) var ( identifiesSchemaInUpload = model.TableSchema{ @@ -530,10 +234,23 @@ func TestLoadUsersTable(t *testing.T) { usersSchamaInUpload = tc.usersSchemaInUpload } + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + schema := map[string]model.TableSchema{ + warehouseutils.UsersTable: usersSchamaInUpload, + warehouseutils.IdentifiesTable: identifiesSchemaInUpload, + } + f := func(tableName string) model.TableSchema { + return schema[tableName] + } + mockUploader := mockuploader.NewMockUploader(ctrl) + mockUploader.EXPECT().GetTableSchemaInUpload(gomock.Any()).AnyTimes().DoAndReturn(f) + mockUploader.EXPECT().GetTableSchemaInWarehouse(gomock.Any()).AnyTimes().DoAndReturn(f) + pg.DB = db pg.Namespace = namespace pg.Warehouse = warehouse - pg.stats = store pg.LoadFileDownloader = &mockLoadFileUploader{ mockFiles: map[string][]string{ warehouseutils.UsersTable: usersLoadFiles, @@ -544,10 +261,7 @@ func TestLoadUsersTable(t *testing.T) { warehouseutils.IdentifiesTable: tc.mockIdentifiesError, }, } - pg.Uploader = newMockUploader(t, map[string]model.TableSchema{ - warehouseutils.UsersTable: usersSchamaInUpload, - warehouseutils.IdentifiesTable: identifiesSchemaInUpload, - }) + pg.Uploader = mockUploader errorsMap := pg.LoadUserTables(ctx) require.NotEmpty(t, errorsMap) @@ -564,12 +278,3 @@ func TestLoadUsersTable(t *testing.T) { }) } } - -func newMockUploader(t testing.TB, schema model.Schema) *mockuploader.MockUploader { - ctrl := gomock.NewController(t) - f := func(tableName string) model.TableSchema { return schema[tableName] } - u := mockuploader.NewMockUploader(ctrl) - u.EXPECT().GetTableSchemaInUpload(gomock.Any()).AnyTimes().DoAndReturn(f) - u.EXPECT().GetTableSchemaInWarehouse(gomock.Any()).AnyTimes().DoAndReturn(f) - return u -} diff --git a/warehouse/integrations/postgres/postgres.go b/warehouse/integrations/postgres/postgres.go index 312302e31b..ba8eec3447 100644 --- a/warehouse/integrations/postgres/postgres.go +++ b/warehouse/integrations/postgres/postgres.go @@ -490,22 +490,6 @@ func (*Postgres) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err e return } -func (pg *Postgres) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) { - var ( - total int64 - err error - sqlStatement string - ) - sqlStatement = fmt.Sprintf(` - SELECT count(*) FROM "%[1]s"."%[2]s"; - `, - pg.Namespace, - tableName, - ) - err = pg.DB.QueryRowContext(ctx, sqlStatement).Scan(&total) - return total, err -} - func (pg *Postgres) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) { if warehouse.Destination.Config["sslMode"] == "verify-ca" { if err := warehouseutils.WriteSSLKeys(warehouse.Destination); err.IsError() { diff --git a/warehouse/integrations/postgres/postgres_test.go b/warehouse/integrations/postgres/postgres_test.go index 4dae8caf9b..3b41bf6163 100644 --- a/warehouse/integrations/postgres/postgres_test.go +++ b/warehouse/integrations/postgres/postgres_test.go @@ -10,6 +10,16 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/logger" + 
"github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-server/warehouse/integrations/postgres" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/rudderlabs/compose-test/compose" "github.com/rudderlabs/rudder-server/testhelper/workspaceConfig" @@ -39,7 +49,8 @@ func TestIntegration(t *testing.T) { } c := testcompose.New(t, compose.FilePaths([]string{ - "testdata/docker-compose.yml", + "testdata/docker-compose.postgres.yml", + "testdata/docker-compose.ssh-server.yml", "../testdata/docker-compose.jobsdb.yml", "../testdata/docker-compose.minio.yml", })) @@ -92,6 +103,7 @@ func TestIntegration(t *testing.T) { bucketName := "testbucket" accessKeyID := "MYACCESSKEY" secretAccessKey := "MYSECRETKEY" + region := "us-east-1" minioEndpoint := fmt.Sprintf("localhost:%d", minioPort) @@ -137,13 +149,8 @@ func TestIntegration(t *testing.T) { t.Setenv("MINIO_SECRET_ACCESS_KEY", secretAccessKey) t.Setenv("MINIO_MINIO_ENDPOINT", minioEndpoint) t.Setenv("MINIO_SSL", "false") - t.Setenv("RSERVER_WAREHOUSE_POSTGRES_MAX_PARALLEL_LOADS", "8") - t.Setenv("RSERVER_WAREHOUSE_POSTGRES_SKIP_COMPUTING_USER_LATEST_TRAITS_WORKSPACE_IDS", workspaceID) - t.Setenv("RSERVER_WAREHOUSE_POSTGRES_ENABLE_SQLSTATEMENT_EXECUTION_PLAN_WORKSPACE_IDS", workspaceID) - t.Setenv("RSERVER_WAREHOUSE_POSTGRES_ENABLE_DELETE_BY_JOBS", "true") t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) - t.Setenv("RSERVER_WAREHOUSE_POSTGRES_SLOW_QUERY_THRESHOLD", "0s") svcDone := make(chan struct{}) @@ -162,6 +169,12 @@ func TestIntegration(t *testing.T) { health.WaitUntilReady(ctx, t, serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint") t.Run("Events flow", func(t *testing.T) { + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_MAX_PARALLEL_LOADS", "8") + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_SKIP_COMPUTING_USER_LATEST_TRAITS_WORKSPACE_IDS", workspaceID) + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_ENABLE_SQLSTATEMENT_EXECUTION_PLAN_WORKSPACE_IDS", workspaceID) + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_ENABLE_DELETE_BY_JOBS", "true") + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_ENABLE_DELETE_BY_JOBS", "true") + dsn := fmt.Sprintf( "postgres://%s:%s@%s:%s/%s?sslmode=disable", "rudder", "rudder-password", "localhost", strconv.Itoa(postgresPort), "rudderdb", @@ -290,6 +303,12 @@ func TestIntegration(t *testing.T) { }) t.Run("Events flow with ssh tunnel", func(t *testing.T) { + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_MAX_PARALLEL_LOADS", "8") + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_SKIP_COMPUTING_USER_LATEST_TRAITS_WORKSPACE_IDS", workspaceID) + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_ENABLE_SQLSTATEMENT_EXECUTION_PLAN_WORKSPACE_IDS", workspaceID) + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_ENABLE_DELETE_BY_JOBS", "true") + t.Setenv("RSERVER_WAREHOUSE_POSTGRES_ENABLE_DELETE_BY_JOBS", "true") + dsn := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", tunnelledUser, tunnelledPassword, @@ -441,4 +460,400 @@ func TestIntegration(t *testing.T) { } testhelper.VerifyConfigurationTest(t, dest) }) + + t.Run("Load Table", func(t *testing.T) { + const ( + sourceID = "test_source_id" + destinationID = "test_destination_id" + workspaceID = "test_workspace_id" + ) + + namespace := testhelper.RandSchema(destType) + + schemaInUpload := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + 
"test_string": "string", + "id": "string", + "received_at": "datetime", + } + schemaInWarehouse := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + "extra_test_bool": "boolean", + "extra_test_datetime": "datetime", + "extra_test_float": "float", + "extra_test_int": "int", + "extra_test_string": "string", + } + + warehouse := model.Warehouse{ + Source: backendconfig.SourceT{ + ID: sourceID, + }, + Destination: backendconfig.DestinationT{ + ID: destinationID, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: destType, + }, + Config: map[string]any{ + "host": host, + "database": database, + "user": user, + "password": password, + "port": strconv.Itoa(postgresPort), + "sslMode": "disable", + "namespace": "", + "bucketProvider": "MINIO", + "bucketName": bucketName, + "accessKeyID": accessKeyID, + "secretAccessKey": secretAccessKey, + "useSSL": false, + "endPoint": minioEndpoint, + "syncFrequency": "30", + "useRudderStorage": false, + }, + }, + WorkspaceID: workspaceID, + Namespace: namespace, + } + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: warehouseutils.MINIO, + Config: map[string]any{ + "bucketName": bucketName, + "accessKeyID": accessKeyID, + "secretAccessKey": secretAccessKey, + "endPoint": minioEndpoint, + "forcePathStyle": true, + "s3ForcePathStyle": true, + "disableSSL": true, + "region": region, + "enableSSE": false, + "bucketProvider": warehouseutils.MINIO, + }, + }) + require.NoError(t, err) + + t.Run("schema does not exists", func(t *testing.T) { + tableName := "schema_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + pg := postgres.New(config.Default, logger.NOP, stats.Default) + err := pg.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + loadTableStat, err := pg.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("table does not exists", func(t *testing.T) { + tableName := "table_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + pg := postgres.New(config.Default, logger.NOP, stats.Default) + err := pg.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = pg.CreateSchema(ctx) + require.NoError(t, err) + + loadTableStat, err := pg.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("merge", func(t *testing.T) { + tableName := "merge_test_table" + + t.Run("without dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + c := config.New() + c.Set("Warehouse.postgres.EnableSQLStatementExecutionPlanWorkspaceIDs", workspaceID) + + pg := postgres.New(c, logger.NOP, stats.Default) + err := pg.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = pg.CreateSchema(ctx) + 
require.NoError(t, err) + + err = pg.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := pg.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = pg.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("with dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + c := config.New() + c.Set("Warehouse.postgres.EnableSQLStatementExecutionPlanWorkspaceIDs", workspaceID) + + pg := postgres.New(config.Default, logger.NOP, stats.Default) + err := pg.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = pg.CreateSchema(ctx) + require.NoError(t, err) + + err = pg.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := pg.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DedupTestRecords()) + }) + }) + t.Run("append", func(t *testing.T) { + tableName := "append_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + c := config.New() + c.Set("Warehouse.postgres.skipDedupDestinationIDs", destinationID) + + pg := postgres.New(c, logger.NOP, stats.Default) + err := pg.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = pg.CreateSchema(ctx) + require.NoError(t, err) + + err = pg.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := pg.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = pg.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.AppendTestRecords()) + }) + t.Run("load file does not exists", func(t *testing.T) { + tableName := "load_file_not_exists_test_table" + + 
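// A minimal sketch of how the merge/append subtests above arrive at their expected
// counts when the same 14-row file is loaded twice: under MERGE the second load
// deletes the 14 matching rows and re-inserts them (0 inserted / 14 updated), while a
// destination listed in Warehouse.postgres.skipDedupDestinationIDs skips the delete
// step, so both loads report 14 inserted / 0 updated. The standalone helper below
// mirrors that selection logic and is illustrative only.
package main

import (
	"fmt"

	"golang.org/x/exp/slices"
)

func loadTableStrategy(skipDedupDestinationIDs []string, destinationID string) string {
	if slices.Contains(skipDedupDestinationIDs, destinationID) {
		return "APPEND" // duplicates from repeated loads are kept
	}
	return "MERGE" // matching rows are deleted first, then re-inserted from staging
}

func main() {
	fmt.Println(loadTableStrategy([]string{"test_destination_id"}, "test_destination_id")) // APPEND
	fmt.Println(loadTableStrategy(nil, "test_destination_id"))                             // MERGE
}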
loadFiles := []warehouseutils.LoadFile{{ + Location: "http://localhost:1234/testbucket/rudder-warehouse-load-objects/load_file_not_exists_test_table/test_source_id/f31af97e-03e8-46d0-8a1a-1786cb85b22c-load_file_not_exists_test_table/load.csv.gz", + }} + mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + pg := postgres.New(config.Default, logger.NOP, stats.Default) + err := pg.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = pg.CreateSchema(ctx) + require.NoError(t, err) + + err = pg.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := pg.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in number of columns", func(t *testing.T) { + tableName := "mismatch_columns_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + pg := postgres.New(config.Default, logger.NOP, stats.Default) + err := pg.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = pg.CreateSchema(ctx) + require.NoError(t, err) + + err = pg.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := pg.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in schema", func(t *testing.T) { + tableName := "mismatch_schema_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := mockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse) + + pg := postgres.New(config.Default, logger.NOP, stats.Default) + err := pg.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = pg.CreateSchema(ctx) + require.NoError(t, err) + + err = pg.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := pg.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("discards", func(t *testing.T) { + tableName := warehouseutils.DiscardsTable + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := mockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema) + + pg := postgres.New(config.Default, logger.NOP, stats.Default) + err := pg.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = pg.CreateSchema(ctx) + require.NoError(t, err) + + err = pg.CreateTable(ctx, tableName, warehouseutils.DiscardsSchema) + require.NoError(t, err) + + loadTableStat, err := pg.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(6)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, pg.DB.DB, + fmt.Sprintf(` + SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM + %q.%q + ORDER BY row_id ASC; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DiscardTestRecords()) + }) + }) +} + +func mockUploader( + t testing.TB, + loadFiles []warehouseutils.LoadFile, + 
tableName string, + schemaInUpload model.TableSchema, + schemaInWarehouse model.TableSchema, +) warehouseutils.Uploader { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockUploader := mockuploader.NewMockUploader(ctrl) + mockUploader.EXPECT().UseRudderStorage().Return(false).AnyTimes() + mockUploader.EXPECT().GetLoadFilesMetadata(gomock.Any(), gomock.Any()).Return(loadFiles).AnyTimes() // Try removing this + mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() + + return mockUploader } diff --git a/warehouse/integrations/postgres/testdata/bad.csv.gz b/warehouse/integrations/postgres/testdata/bad.csv.gz deleted file mode 100644 index bae478223a..0000000000 Binary files a/warehouse/integrations/postgres/testdata/bad.csv.gz and /dev/null differ diff --git a/warehouse/integrations/postgres/testdata/docker-compose.postgres.yml b/warehouse/integrations/postgres/testdata/docker-compose.postgres.yml new file mode 100644 index 0000000000..732e3c2809 --- /dev/null +++ b/warehouse/integrations/postgres/testdata/docker-compose.postgres.yml @@ -0,0 +1,15 @@ +version: "3.9" + +services: + postgres: + image: postgres:15-alpine + environment: + - POSTGRES_DB=rudderdb + - POSTGRES_PASSWORD=rudder-password + - POSTGRES_USER=rudder + ports: + - "5432" + healthcheck: + test: [ "CMD-SHELL", "pg_isready" ] + interval: 1s + retries: 25 diff --git a/warehouse/integrations/postgres/testdata/docker-compose.yml b/warehouse/integrations/postgres/testdata/docker-compose.ssh-server.yml similarity index 77% rename from warehouse/integrations/postgres/testdata/docker-compose.yml rename to warehouse/integrations/postgres/testdata/docker-compose.ssh-server.yml index 105ed005e8..ffe6ec71cf 100644 --- a/warehouse/integrations/postgres/testdata/docker-compose.yml +++ b/warehouse/integrations/postgres/testdata/docker-compose.ssh-server.yml @@ -31,15 +31,3 @@ services: interval: 1s timeout: 5s retries: 30 - postgres: - image: postgres:15-alpine - environment: - - POSTGRES_DB=rudderdb - - POSTGRES_PASSWORD=rudder-password - - POSTGRES_USER=rudder - ports: - - "5432" - healthcheck: - test: [ "CMD-SHELL", "pg_isready" ] - interval: 1s - retries: 25 diff --git a/warehouse/integrations/postgres/testdata/less-records.csv.gz b/warehouse/integrations/postgres/testdata/less-records.csv.gz deleted file mode 100644 index 66ad9cc0b9..0000000000 Binary files a/warehouse/integrations/postgres/testdata/less-records.csv.gz and /dev/null differ diff --git a/warehouse/integrations/postgres/testdata/load.csv.gz b/warehouse/integrations/postgres/testdata/load.csv.gz deleted file mode 100644 index e8cf0221ac..0000000000 Binary files a/warehouse/integrations/postgres/testdata/load.csv.gz and /dev/null differ diff --git a/warehouse/integrations/redshift/redshift.go b/warehouse/integrations/redshift/redshift.go index cc662cc72a..a86ac26c7a 100644 --- a/warehouse/integrations/redshift/redshift.go +++ b/warehouse/integrations/redshift/redshift.go @@ -14,6 +14,10 @@ import ( "strings" "time" + "github.com/samber/lo" + + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" + "golang.org/x/exp/slices" "github.com/lib/pq" @@ -162,18 +166,18 @@ type Redshift struct { } } -type S3ManifestEntryMetadata struct { +type s3ManifestEntryMetadata struct { ContentLength int64 `json:"content_length"` } -type S3ManifestEntry struct { +type s3ManifestEntry struct { Url string `json:"url"` Mandatory bool 
`json:"mandatory"` - Metadata S3ManifestEntryMetadata `json:"meta"` + Metadata s3ManifestEntryMetadata `json:"meta"` } -type S3Manifest struct { - Entries []S3ManifestEntry `json:"entries"` +type s3Manifest struct { + Entries []s3ManifestEntry `json:"entries"` } type RedshiftCredentials struct { @@ -347,40 +351,61 @@ func (rs *Redshift) createSchema(ctx context.Context) (err error) { } func (rs *Redshift) generateManifest(ctx context.Context, tableName string) (string, error) { - loadFiles := rs.Uploader.GetLoadFilesMetadata(ctx, warehouseutils.GetLoadFilesOptions{Table: tableName}) - loadFiles = warehouseutils.GetS3Locations(loadFiles) - var manifest S3Manifest - for idx, loadFile := range loadFiles { - manifestEntry := S3ManifestEntry{Url: loadFile.Location, Mandatory: true} + loadFiles := warehouseutils.GetS3Locations(rs.Uploader.GetLoadFilesMetadata( + ctx, + warehouseutils.GetLoadFilesOptions{ + Table: tableName, + }, + )) + + entries := lo.Map(loadFiles, func(loadFile warehouseutils.LoadFile, index int) s3ManifestEntry { + manifestEntry := s3ManifestEntry{ + Url: loadFile.Location, + Mandatory: true, + } + // add contentLength to manifest entry if it exists - contentLength := gjson.Get(string(loadFiles[idx].Metadata), "content_length") + contentLength := gjson.Get(string(loadFile.Metadata), "content_length") if contentLength.Exists() { manifestEntry.Metadata.ContentLength = contentLength.Int() } - manifest.Entries = append(manifest.Entries, manifestEntry) + + return manifestEntry + }) + + manifestJSON, err := json.Marshal(&s3Manifest{ + Entries: entries, + }) + if err != nil { + return "", fmt.Errorf("marshalling manifest: %v", err) } - rs.logger.Infof("RS: Generated manifest for table:%s", tableName) - manifestJSON, _ := json.Marshal(&manifest) - manifestFolder := misc.RudderRedshiftManifests - dirName := "/" + manifestFolder + "/" tmpDirPath, err := misc.CreateTMPDIR() if err != nil { panic(err) } - localManifestPath := fmt.Sprintf("%v%v", tmpDirPath+dirName, misc.FastUUID().String()) + + localManifestPath := tmpDirPath + "/" + misc.RudderRedshiftManifests + "/" + misc.FastUUID().String() err = os.MkdirAll(filepath.Dir(localManifestPath), os.ModePerm) if err != nil { - panic(err) + return "", fmt.Errorf("creating manifest directory: %v", err) + } + + defer func() { + misc.RemoveFilePaths(localManifestPath) + }() + + err = os.WriteFile(localManifestPath, manifestJSON, 0o644) + if err != nil { + return "", fmt.Errorf("writing manifest to file: %v", err) } - defer misc.RemoveFilePaths(localManifestPath) - _ = os.WriteFile(localManifestPath, manifestJSON, 0o644) file, err := os.Open(localManifestPath) if err != nil { - panic(err) + return "", fmt.Errorf("opening manifest file: %v", err) } defer func() { _ = file.Close() }() + uploader, err := filemanager.New(&filemanager.Settings{ Provider: warehouseutils.S3, Config: misc.GetObjectStorageConfig(misc.ObjectStorageOptsT{ @@ -391,12 +416,17 @@ func (rs *Redshift) generateManifest(ctx context.Context, tableName string) (str }), }) if err != nil { - return "", err + return "", fmt.Errorf("creating uploader: %w", err) } - uploadOutput, err := uploader.Upload(ctx, file, manifestFolder, rs.Warehouse.Source.ID, rs.Warehouse.Destination.ID, time.Now().Format("01-02-2006"), tableName, misc.FastUUID().String()) + uploadOutput, err := uploader.Upload( + ctx, file, misc.RudderRedshiftManifests, + rs.Warehouse.Source.ID, rs.Warehouse.Destination.ID, + time.Now().Format("01-02-2006"), tableName, + misc.FastUUID().String(), + ) if err != nil { - return 
"", err + return "", fmt.Errorf("uploading manifest file: %w", err) } return uploadOutput.Location, nil @@ -412,17 +442,14 @@ func (rs *Redshift) dropStagingTables(ctx context.Context, stagingTableNames []s } } -func (rs *Redshift) loadTable(ctx context.Context, tableName string, tableSchemaInUpload, tableSchemaAfterUpload model.TableSchema, skipTempTableDelete bool) (string, error) { - var ( - err error - query string - stagingTableName string - rowsAffected int64 - txn *sqlmiddleware.Tx - result sql.Result - ) - - rs.logger.Infow("started loading", +func (rs *Redshift) loadTable( + ctx context.Context, + tableName string, + tableSchemaInUpload, + tableSchemaAfterUpload model.TableSchema, + skipTempTableDelete bool, +) (*types.LoadTableStats, string, error) { + log := rs.logger.With( logfield.SourceID, rs.Warehouse.Source.ID, logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, logfield.DestinationID, rs.Warehouse.Destination.ID, @@ -430,67 +457,41 @@ func (rs *Redshift) loadTable(ctx context.Context, tableName string, tableSchema logfield.WorkspaceID, rs.Warehouse.WorkspaceID, logfield.Namespace, rs.Namespace, logfield.TableName, tableName, + logfield.LoadTableStrategy, rs.loadTableStrategy(), ) + log.Infow("started loading") manifestLocation, err := rs.generateManifest(ctx, tableName) if err != nil { - return "", fmt.Errorf("generating manifest: %w", err) + return nil, "", fmt.Errorf("generating manifest: %w", err) } + log.Debugw("generated manifest", "manifestLocation", manifestLocation) - rs.logger.Infow("Generated manifest", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - "manifestLocation", manifestLocation, + stagingTableName := warehouseutils.StagingTableName( + provider, + tableName, + tableNameLimit, ) - strKeys := warehouseutils.GetColumnsFromTableSchema(tableSchemaInUpload) - sort.Strings(strKeys) - sortedColumnNames := warehouseutils.JoinWithFormatting(strKeys, func(_ int, name string) string { - return fmt.Sprintf(`%q`, name) - }, ",") - - stagingTableName = warehouseutils.StagingTableName(provider, tableName, tableNameLimit) - _, err = rs.DB.ExecContext(ctx, fmt.Sprintf(`CREATE TABLE %[1]q.%[2]q (LIKE %[1]q.%[3]q INCLUDING DEFAULTS);`, + log.Debugw("creating staging table") + createStagingTableStmt := fmt.Sprintf(`CREATE TABLE %[1]q.%[2]q (LIKE %[1]q.%[3]q INCLUDING DEFAULTS);`, rs.Namespace, stagingTableName, tableName, - )) - if err != nil { - return "", fmt.Errorf("creating staging table: %w", err) + ) + if _, err = rs.DB.ExecContext(ctx, createStagingTableStmt); err != nil { + return nil, "", fmt.Errorf("creating staging table: %w", err) } if !skipTempTableDelete { - defer rs.dropStagingTables(ctx, []string{stagingTableName}) + defer func() { + rs.dropStagingTables(ctx, []string{stagingTableName}) + }() } - manifestS3Location, region := warehouseutils.GetS3Location(manifestLocation) - if region == "" { - region = "us-east-1" - } - - // create session token and temporary credentials - tempAccessKeyId, tempSecretAccessKey, token, err := warehouseutils.GetTemporaryS3Cred(&rs.Warehouse.Destination) + txn, err := rs.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { - rs.logger.Warnw("getting temporary s3 credentials", - logfield.SourceID, 
rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - logfield.Error, err.Error(), - ) - return "", fmt.Errorf("getting temporary s3 credentials: %w", err) - } - - if txn, err = rs.DB.BeginTx(ctx, &sql.TxOptions{}); err != nil { - return "", fmt.Errorf("begin transaction: %w", err) + return nil, "", fmt.Errorf("begin transaction: %w", err) } defer func() { if err != nil { @@ -498,15 +499,90 @@ func (rs *Redshift) loadTable(ctx context.Context, tableName string, tableSchema } }() + strKeys := warehouseutils.GetColumnsFromTableSchema(tableSchemaInUpload) + sort.Strings(strKeys) + + log.Infow("loading data into staging table") + err = rs.copyIntoLoadTable( + ctx, txn, stagingTableName, + manifestLocation, strKeys, + ) + if err != nil { + return nil, "", fmt.Errorf("loading data into staging table: %w", err) + } + + var rowsDeleted int64 + if !slices.Contains(rs.config.skipDedupDestinationIDs, rs.Warehouse.Destination.ID) { + log.Infow("deleting from load table") + rowsDeleted, err = rs.deleteFromLoadTable( + ctx, txn, tableName, + stagingTableName, tableSchemaAfterUpload, + ) + if err != nil { + return nil, "", fmt.Errorf("delete from load table: %w", err) + } + } + + log.Infow("inserting into load table") + rowsInserted, err := rs.insertIntoLoadTable( + ctx, txn, tableName, + stagingTableName, strKeys, + ) + if err != nil { + return nil, "", fmt.Errorf("insert into: %w", err) + } + + log.Debugw("committing transaction") + if err = txn.Commit(); err != nil { + return nil, "", fmt.Errorf("commit transaction: %w", err) + } + + log.Infow("completed loading") + + return &types.LoadTableStats{ + RowsInserted: rowsInserted - rowsDeleted, + RowsUpdated: rowsDeleted, + }, stagingTableName, nil +} + +func (rs *Redshift) loadTableStrategy() string { + if slices.Contains(rs.config.skipDedupDestinationIDs, rs.Warehouse.Destination.ID) { + return "APPEND" + } + return "MERGE" +} + +func (rs *Redshift) copyIntoLoadTable( + ctx context.Context, + txn *sqlmiddleware.Tx, + stagingTableName string, + manifestLocation string, + strKeys []string, +) error { + tempAccessKeyId, tempSecretAccessKey, token, err := warehouseutils.GetTemporaryS3Cred(&rs.Warehouse.Destination) + if err != nil { + return fmt.Errorf("getting temporary s3 credentials: %w", err) + } + + manifestS3Location, region := warehouseutils.GetS3Location(manifestLocation) + if region == "" { + region = "us-east-1" + } + + sortedColumnNames := warehouseutils.JoinWithFormatting(strKeys, func(_ int, name string) string { + return fmt.Sprintf(`%q`, name) + }, ",") + + var copyStmt string if rs.Uploader.GetLoadFileType() == warehouseutils.LoadFileTypeParquet { - query = fmt.Sprintf(` - COPY %v - FROM '%s' + copyStmt = fmt.Sprintf(` + COPY %s + FROM + '%s' ACCESS_KEY_ID '%s' SECRET_ACCESS_KEY '%s' SESSION_TOKEN '%s' - MANIFEST - FORMAT PARQUET; + MANIFEST FORMAT PARQUET; `, fmt.Sprintf(`%q.%q`, rs.Namespace, stagingTableName), manifestS3Location, @@ -515,25 +591,18 @@ func (rs *Redshift) loadTable(ctx context.Context, tableName string, tableSchema token, ) } else { - query = fmt.Sprintf(` - COPY %v(%v) - FROM '%v' - CSV - GZIP + copyStmt = fmt.Sprintf(` + COPY %s(%s) + FROM + '%s' + CSV GZIP ACCESS_KEY_ID '%s' SECRET_ACCESS_KEY '%s' SESSION_TOKEN '%s' REGION '%s' 
DATEFORMAT 'auto' TIMEFORMAT 'auto' - MANIFEST - TRUNCATECOLUMNS - EMPTYASNULL - BLANKSASNULL - FILLRECORD - ACCEPTANYDATE - TRIMBLANKS - ACCEPTINVCHARS + MANIFEST TRUNCATECOLUMNS EMPTYASNULL BLANKSASNULL FILLRECORD ACCEPTANYDATE TRIMBLANKS ACCEPTINVCHARS COMPUPDATE OFF STATUPDATE OFF; `, @@ -547,74 +616,41 @@ func (rs *Redshift) loadTable(ctx context.Context, tableName string, tableSchema ) } - sanitisedQuery, regexErr := misc.ReplaceMultiRegex(query, map[string]string{ - "ACCESS_KEY_ID '[^']*'": "ACCESS_KEY_ID '***'", - "SECRET_ACCESS_KEY '[^']*'": "SECRET_ACCESS_KEY '***'", - "SESSION_TOKEN '[^']*'": "SESSION_TOKEN '***'", - }) - if regexErr != nil { - sanitisedQuery = "" - } - - rs.logger.Infow("copy command", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - logfield.Query, sanitisedQuery, - ) - - if _, err := txn.ExecContext(ctx, query); err != nil { - rs.logger.Warnw("failure running copy command", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - logfield.Query, sanitisedQuery, - logfield.Error, err.Error(), - ) - - return "", fmt.Errorf("running copy command: %w", normalizeError(err)) + if _, err := txn.ExecContext(ctx, copyStmt); err != nil { + return fmt.Errorf("running copy command: %w", normalizeError(err)) } + return nil +} - var ( - primaryKey = "id" - partitionKey = "id" - ) - +func (rs *Redshift) deleteFromLoadTable( + ctx context.Context, + txn *sqlmiddleware.Tx, + tableName string, + stagingTableName string, + tableSchemaAfterUpload model.TableSchema, +) (int64, error) { + primaryKey := "id" if column, ok := primaryKeyMap[tableName]; ok { primaryKey = column } - if column, ok := partitionKeyMap[tableName]; ok { - partitionKey = column - } - // Deduplication - // Delete rows from the table which are already present in the staging table - query = fmt.Sprintf(` - DELETE FROM - %[1]s.%[2]q - USING - %[1]s.%[3]q _source - WHERE - _source.%[4]s = %[1]s.%[2]q.%[4]s + deleteStmt := fmt.Sprintf(` + DELETE FROM + %[1]s.%[2]q + USING + %[1]s.%[3]q _source + WHERE + _source.%[4]s = %[1]s.%[2]q.%[4]s `, rs.Namespace, tableName, stagingTableName, primaryKey, ) - if rs.config.dedupWindow { if _, ok := tableSchemaAfterUpload["received_at"]; ok { - query += fmt.Sprintf(` - AND %[1]s.%[2]q.received_at > GETDATE() - INTERVAL '%[3]d HOUR' + deleteStmt += fmt.Sprintf(` + AND %[1]s.%[2]q.received_at > GETDATE() - INTERVAL '%[3]d HOUR' `, rs.Namespace, tableName, @@ -622,11 +658,10 @@ func (rs *Redshift) loadTable(ctx context.Context, tableName string, tableSchema ) } } - if tableName == warehouseutils.DiscardsTable { - query += fmt.Sprintf(` - AND _source.%[3]s = %[1]s.%[2]q.%[3]s - AND _source.%[4]s = %[1]s.%[2]q.%[4]s + deleteStmt += fmt.Sprintf(` + AND _source.%[3]s = %[1]s.%[2]q.%[3]s + AND _source.%[4]s = %[1]s.%[2]q.%[4]s `, rs.Namespace, tableName, @@ -635,64 +670,30 @@ func (rs *Redshift) loadTable(ctx context.Context, tableName string, tableSchema ) } - if 
!slices.Contains(rs.config.skipDedupDestinationIDs, rs.Warehouse.Destination.ID) { - rs.logger.Infow("deduplication", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - logfield.Query, query, - ) - - if result, err = txn.ExecContext(ctx, query); err != nil { - rs.logger.Warnw("deleting from original table for dedup", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - logfield.Query, query, - logfield.Error, err.Error(), - ) - return "", fmt.Errorf("deleting from original table for dedup: %w", normalizeError(err)) - } - - if rowsAffected, err = result.RowsAffected(); err != nil { - rs.logger.Warnw("getting rows affected for dedup", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - logfield.Query, query, - logfield.Error, err.Error(), - ) - - return "", fmt.Errorf("getting rows affected for dedup: %w", err) - } + result, err := txn.ExecContext(ctx, deleteStmt) + if err != nil { + return 0, fmt.Errorf("deleting from main table for dedup: %w", normalizeError(err)) + } + return result.RowsAffected() +} - rs.stats.NewTaggedStat("dedup_rows", stats.CountType, stats.Tags{ - "sourceID": rs.Warehouse.Source.ID, - "sourceType": rs.Warehouse.Source.SourceDefinition.Name, - "sourceCategory": rs.Warehouse.Source.SourceDefinition.Category, - "destID": rs.Warehouse.Destination.ID, - "destType": rs.Warehouse.Destination.DestinationDefinition.Name, - "workspaceId": rs.Warehouse.WorkspaceID, - "tableName": tableName, - }).Count(int(rowsAffected)) +func (rs *Redshift) insertIntoLoadTable( + ctx context.Context, + txn *sqlmiddleware.Tx, + tableName string, + stagingTableName string, + sortedColumnKeys []string, +) (int64, error) { + partitionKey := "id" + if column, ok := partitionKeyMap[tableName]; ok { + partitionKey = column } - // Deduplication - // Insert rows from staging table to the original table - quotedColumnNames := warehouseutils.DoubleQuoteAndJoinByComma(strKeys) - query = fmt.Sprintf(` + quotedColumnNames := warehouseutils.DoubleQuoteAndJoinByComma( + sortedColumnKeys, + ) + + insertStmt := fmt.Sprintf(` INSERT INTO %[1]q.%[2]q (%[3]s) SELECT %[3]s @@ -718,58 +719,11 @@ func (rs *Redshift) loadTable(ctx context.Context, tableName string, tableSchema partitionKey, ) - rs.logger.Infow("inserting into original table", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - 
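// A minimal sketch of how deleteFromLoadTable narrows the dedup DELETE when
// Warehouse.redshift.dedupWindow is enabled: an extra received_at predicate limits
// matching to the configured window, so with dedupWindowInHours set to 0 (as in the
// "with dedup window" subtest) no existing row qualifies and a repeated load inserts
// all 14 rows again instead of updating them. The helper and identifiers below are
// hypothetical; only the SQL shape follows the statement built in the diff.
package main

import "fmt"

func dedupDeleteStmt(namespace, tableName, stagingTableName string, dedupWindow bool, dedupWindowInHours int64) string {
	stmt := fmt.Sprintf(
		`DELETE FROM %[1]s.%[2]q USING %[1]s.%[3]q _source WHERE _source.id = %[1]s.%[2]q.id`,
		namespace, tableName, stagingTableName,
	)
	if dedupWindow {
		stmt += fmt.Sprintf(
			` AND %[1]s.%[2]q.received_at > GETDATE() - INTERVAL '%[3]d HOUR'`,
			namespace, tableName, dedupWindowInHours,
		)
	}
	return stmt
}

func main() {
	// A zero-hour window keeps the DELETE from matching any existing row.
	fmt.Println(dedupDeleteStmt("test_namespace", "merge_test_table", "rudder_staging_merge_test_table", true, 0))
}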
logfield.TableName, tableName, - logfield.Query, query, - ) - - if _, err = txn.ExecContext(ctx, query); err != nil { - rs.logger.Warnw("failed inserting into original table", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - logfield.Error, err.Error(), - ) - - return "", fmt.Errorf("inserting into original table: %w", normalizeError(err)) - } - - if err = txn.Commit(); err != nil { - rs.logger.Warnw("committing transaction", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - logfield.Error, err.Error(), - ) - - return "", fmt.Errorf("committing transaction: %w", err) + r, err := txn.ExecContext(ctx, insertStmt) + if err != nil { + return 0, fmt.Errorf("inserting into main table: %w", err) } - - rs.logger.Infow("completed loading", - logfield.SourceID, rs.Warehouse.Source.ID, - logfield.SourceType, rs.Warehouse.Source.SourceDefinition.Name, - logfield.DestinationID, rs.Warehouse.Destination.ID, - logfield.DestinationType, rs.Warehouse.Destination.DestinationDefinition.Name, - logfield.WorkspaceID, rs.Warehouse.WorkspaceID, - logfield.Namespace, rs.Namespace, - logfield.TableName, tableName, - ) - - return stagingTableName, nil + return r.RowsAffected() } func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { @@ -791,7 +745,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { logfield.Namespace, rs.Namespace, ) - identifyStagingTable, err = rs.loadTable(ctx, warehouseutils.IdentifiesTable, rs.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), rs.Uploader.GetTableSchemaInWarehouse(warehouseutils.IdentifiesTable), true) + _, identifyStagingTable, err = rs.loadTable(ctx, warehouseutils.IdentifiesTable, rs.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), rs.Uploader.GetTableSchemaInWarehouse(warehouseutils.IdentifiesTable), true) if err != nil { return map[string]error{ warehouseutils.IdentifiesTable: fmt.Errorf("loading identifies table: %w", err), @@ -807,7 +761,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { } if rs.config.skipComputingUserLatestTraits { - _, err := rs.loadTable(ctx, warehouseutils.UsersTable, rs.Uploader.GetTableSchemaInUpload(warehouseutils.UsersTable), rs.Uploader.GetTableSchemaInWarehouse(warehouseutils.UsersTable), false) + _, _, err := rs.loadTable(ctx, warehouseutils.UsersTable, rs.Uploader.GetTableSchemaInUpload(warehouseutils.UsersTable), rs.Uploader.GetTableSchemaInWarehouse(warehouseutils.UsersTable), false) if err != nil { return map[string]error{ warehouseutils.IdentifiesTable: nil, @@ -938,7 +892,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { logfield.Error, err.Error(), ) return map[string]error{ - warehouseutils.UsersTable: fmt.Errorf("deleting from original table for dedup: %w", normalizeError(err)), + warehouseutils.UsersTable: fmt.Errorf("deleting from main table for dedup: %w", 
normalizeError(err)), } } @@ -1398,9 +1352,15 @@ func (rs *Redshift) LoadUserTables(ctx context.Context) map[string]error { return rs.loadUserTables(ctx) } -func (rs *Redshift) LoadTable(ctx context.Context, tableName string) error { - _, err := rs.loadTable(ctx, tableName, rs.Uploader.GetTableSchemaInUpload(tableName), rs.Uploader.GetTableSchemaInWarehouse(tableName), false) - return err +func (rs *Redshift) LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) { + loadTableStat, _, err := rs.loadTable( + ctx, + tableName, + rs.Uploader.GetTableSchemaInUpload(tableName), + rs.Uploader.GetTableSchemaInWarehouse(tableName), + false, + ) + return loadTableStat, err } func (*Redshift) LoadIdentityMergeRulesTable(context.Context) (err error) { @@ -1415,22 +1375,6 @@ func (*Redshift) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err e return } -func (rs *Redshift) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) { - var ( - total int64 - err error - sqlStatement string - ) - sqlStatement = fmt.Sprintf(` - SELECT count(*) FROM "%[1]s"."%[2]s"; - `, - rs.Namespace, - tableName, - ) - err = rs.DB.QueryRowContext(ctx, sqlStatement).Scan(&total) - return total, err -} - func (rs *Redshift) Connect(ctx context.Context, warehouse model.Warehouse) (client.Client, error) { rs.Warehouse = warehouse rs.Namespace = warehouse.Namespace diff --git a/warehouse/integrations/redshift/redshift_test.go b/warehouse/integrations/redshift/redshift_test.go index 25df6672fc..b1bc288af1 100644 --- a/warehouse/integrations/redshift/redshift_test.go +++ b/warehouse/integrations/redshift/redshift_test.go @@ -12,6 +12,13 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" + "golang.org/x/exp/slices" + + "github.com/rudderlabs/rudder-go-kit/filemanager" + mockuploader "github.com/rudderlabs/rudder-server/warehouse/internal/mocks/utils" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" @@ -139,13 +146,8 @@ func TestIntegration(t *testing.T) { testhelper.EnhanceWithDefaultEnvs(t) t.Setenv("JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) t.Setenv("WAREHOUSE_JOBS_DB_PORT", strconv.Itoa(jobsDBPort)) - t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_MAX_PARALLEL_LOADS", "8") - t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_ENABLE_DELETE_BY_JOBS", "true") - t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_DEDUP_WINDOW", "true") - t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_DEDUP_WINDOW_IN_HOURS", "5") t.Setenv("RSERVER_WAREHOUSE_WEB_PORT", strconv.Itoa(httpPort)) t.Setenv("RSERVER_BACKEND_CONFIG_CONFIG_JSONPATH", workspaceConfigPath) - t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_SLOW_QUERY_THRESHOLD", "0s") svcDone := make(chan struct{}) @@ -176,6 +178,12 @@ func TestIntegration(t *testing.T) { require.NoError(t, db.Ping()) t.Run("Event flow", func(t *testing.T) { + t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_MAX_PARALLEL_LOADS", "8") + t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_ENABLE_DELETE_BY_JOBS", "true") + t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_SLOW_QUERY_THRESHOLD", "0s") + t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_DEDUP_WINDOW", "true") + t.Setenv("RSERVER_WAREHOUSE_REDSHIFT_DEDUP_WINDOW_IN_HOURS", "5") + jobsDB := testhelper.JobsDB(t, jobsDBPort) testcase := []struct { @@ -345,6 +353,477 @@ func TestIntegration(t *testing.T) { } testhelper.VerifyConfigurationTest(t, dest) }) + + t.Run("Load Table", func(t *testing.T) { + const ( + sourceID = "test_source_id" + destinationID = "test_destination_id" + workspaceID = 
"test_workspace_id" + ) + + namespace := testhelper.RandSchema(destType) + + t.Cleanup(func() { + require.Eventually(t, func() bool { + if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %q CASCADE;`, namespace)); err != nil { + t.Logf("error deleting schema: %v", err) + return false + } + return true + }, + time.Minute, + time.Second, + ) + }) + + schemaInUpload := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + } + schemaInWarehouse := model.TableSchema{ + "test_bool": "boolean", + "test_datetime": "datetime", + "test_float": "float", + "test_int": "int", + "test_string": "string", + "id": "string", + "received_at": "datetime", + "extra_test_bool": "boolean", + "extra_test_datetime": "datetime", + "extra_test_float": "float", + "extra_test_int": "int", + "extra_test_string": "string", + } + + warehouse := model.Warehouse{ + Source: backendconfig.SourceT{ + ID: sourceID, + }, + Destination: backendconfig.DestinationT{ + ID: destinationID, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: destType, + }, + Config: map[string]any{ + "host": rsTestCredentials.Host, + "port": rsTestCredentials.Port, + "user": rsTestCredentials.UserName, + "password": rsTestCredentials.Password, + "database": rsTestCredentials.DbName, + "bucketName": rsTestCredentials.BucketName, + "accessKeyID": rsTestCredentials.AccessKeyID, + "accessKey": rsTestCredentials.AccessKey, + "namespace": namespace, + "syncFrequency": "30", + "enableSSE": false, + "useRudderStorage": false, + }, + }, + WorkspaceID: workspaceID, + Namespace: namespace, + } + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: warehouseutils.S3, + Config: map[string]any{ + "bucketName": rsTestCredentials.BucketName, + "accessKeyID": rsTestCredentials.AccessKeyID, + "accessKey": rsTestCredentials.AccessKey, + "bucketProvider": warehouseutils.S3, + }, + }) + require.NoError(t, err) + + t.Run("schema does not exists", func(t *testing.T) { + tableName := "schema_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + rs := redshift.New(config.Default, logger.NOP, stats.Default) + err := rs.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + loadTableStat, err := rs.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("table does not exists", func(t *testing.T) { + tableName := "table_not_exists_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + rs := redshift.New(config.Default, logger.NOP, stats.Default) + err := rs.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = rs.CreateSchema(ctx) + require.NoError(t, err) + + loadTableStat, err := rs.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("merge", func(t *testing.T) { + tableName := "merge_test_table" + + t.Run("without dedup", func(t *testing.T) { + uploadOutput := 
testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + d := redshift.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("with dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + d := redshift.New(config.Default, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DedupTestRecords()) + }) + t.Run("with dedup window", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + c := config.New() + c.Set("Warehouse.redshift.dedupWindow", true) + c.Set("Warehouse.redshift.dedupWindowInHours", 0) + + d := redshift.New(c, logger.NOP, stats.Default) + err := d.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = d.CreateSchema(ctx) + require.NoError(t, err) + + err = d.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := d.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, d.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %s.%s + ORDER BY + id; + `, + namespace, + 
tableName, + ), + ) + require.Equal(t, records, testhelper.DedupTwiceTestRecords()) + }) + }) + t.Run("append", func(t *testing.T) { + tableName := "append_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + c := config.New() + c.Set("Warehouse.redshift.skipDedupDestinationIDs", []string{destinationID}) + + rs := redshift.New(c, logger.NOP, stats.Default) + err := rs.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = rs.CreateSchema(ctx) + require.NoError(t, err) + + err = rs.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := rs.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = rs.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.AppendTestRecords()) + }) + t.Run("load file does not exists", func(t *testing.T) { + tableName := "load_file_not_exists_test_table" + + loadFiles := []warehouseutils.LoadFile{{ + Location: "https://bucket.s3.amazonaws.com/rudder-warehouse-load-objects/load_file_not_exists_test_table/test_source_id/0ef75cb0-3fd0-4408-98b9-2bea9e476916-load_file_not_exists_test_table/load.csv.gz", + }} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + rs := redshift.New(config.Default, logger.NOP, stats.Default) + err := rs.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = rs.CreateSchema(ctx) + require.NoError(t, err) + + err = rs.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := rs.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in number of columns", func(t *testing.T) { + tableName := "mismatch_columns_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + rs := redshift.New(config.Default, logger.NOP, stats.Default) + err := rs.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = rs.CreateSchema(ctx) + require.NoError(t, err) + + err = rs.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := rs.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in schema", func(t *testing.T) { + tableName := "mismatch_schema_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := 
newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv) + + rs := redshift.New(config.Default, logger.NOP, stats.Default) + err := rs.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = rs.CreateSchema(ctx) + require.NoError(t, err) + + err = rs.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := rs.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("discards", func(t *testing.T) { + tableName := warehouseutils.DiscardsTable + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + + loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, warehouseutils.DiscardsSchema, warehouseutils.DiscardsSchema, warehouseutils.LoadFileTypeCsv) + + rs := redshift.New(config.Default, logger.NOP, stats.Default) + err := rs.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = rs.CreateSchema(ctx) + require.NoError(t, err) + + err = rs.CreateTable(ctx, tableName, warehouseutils.DiscardsSchema) + require.NoError(t, err) + + loadTableStat, err := rs.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(6)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB, + fmt.Sprintf(` + SELECT + column_name, + column_value, + received_at, + row_id, + table_name, + uuid_ts + FROM + %q.%q + ORDER BY row_id ASC; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DiscardTestRecords()) + }) + t.Run("parquet", func(t *testing.T) { + tableName := "parquet_test_table" + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.parquet", tableName) + + fileStat, err := os.Stat("../testdata/load.parquet") + require.NoError(t, err) + + loadFiles := []warehouseutils.LoadFile{{ + Location: uploadOutput.Location, + Metadata: json.RawMessage(fmt.Sprintf(`{"content_length": %d}`, fileStat.Size())), + }} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInUpload, warehouseutils.LoadFileTypeParquet) + + rs := redshift.New(config.Default, logger.NOP, stats.Default) + err = rs.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = rs.CreateSchema(ctx) + require.NoError(t, err) + + err = rs.CreateTable(ctx, tableName, schemaInUpload) + require.NoError(t, err) + + loadTableStat, err := rs.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + }) } func TestCheckAndIgnoreColumnAlreadyExistError(t *testing.T) { @@ -489,3 +968,28 @@ func TestRedshift_AlterColumn(t *testing.T) { }) } } + +func newMockUploader( + t testing.TB, + loadFiles []warehouseutils.LoadFile, + tableName string, + schemaInUpload model.TableSchema, + schemaInWarehouse model.TableSchema, + loadFileType string, +) warehouseutils.Uploader { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockUploader := mockuploader.NewMockUploader(ctrl) + 
mockUploader.EXPECT().UseRudderStorage().Return(false).AnyTimes() + mockUploader.EXPECT().GetLoadFilesMetadata(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, options warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { + return slices.Clone(loadFiles) + }, + ).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() + mockUploader.EXPECT().GetLoadFileType().Return(loadFileType).AnyTimes() + + return mockUploader +} diff --git a/warehouse/integrations/snowflake/snowflake.go b/warehouse/integrations/snowflake/snowflake.go index 6c41c66203..eed8ab2aba 100644 --- a/warehouse/integrations/snowflake/snowflake.go +++ b/warehouse/integrations/snowflake/snowflake.go @@ -9,9 +9,12 @@ import ( "fmt" "regexp" "sort" + "strconv" "strings" "time" + "github.com/rudderlabs/rudder-server/warehouse/integrations/types" + "github.com/samber/lo" snowflake "github.com/snowflakedb/gosnowflake" @@ -338,7 +341,12 @@ func (sf *Snowflake) DeleteBy(ctx context.Context, tableNames []string, params w return nil } -func (sf *Snowflake) loadTable(ctx context.Context, tableName string, tableSchemaInUpload model.TableSchema, skipClosingDBSession bool) (tableLoadResp, error) { +func (sf *Snowflake) loadTable( + ctx context.Context, + tableName string, + tableSchemaInUpload model.TableSchema, + skipClosingDBSession bool, +) (*types.LoadTableStats, *tableLoadResp, error) { var ( db *sqlmw.DB err error @@ -357,7 +365,7 @@ func (sf *Snowflake) loadTable(ctx context.Context, tableName string, tableSchem log.Infow("started loading") if db, err = sf.connect(ctx, optionalCreds{schemaName: sf.Namespace}); err != nil { - return tableLoadResp{}, fmt.Errorf("connect: %w", err) + return nil, nil, fmt.Errorf("connect: %w", err) } if !skipClosingDBSession { @@ -365,51 +373,54 @@ func (sf *Snowflake) loadTable(ctx context.Context, tableName string, tableSchem } schemaIdentifier := sf.schemaIdentifier() - stagingTableName := whutils.StagingTableName(provider, tableName, tableNameLimit) + stagingTableName := whutils.StagingTableName( + provider, + tableName, + tableNameLimit, + ) + strKeys := sf.getSortedColumnsFromTableSchema(tableSchemaInUpload) sortedColumnNames := sf.joinColumnsWithFormatting(strKeys, "%q") // Truncating the columns by default to avoid size limitation errors // https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions if sf.ShouldAppend() { - err = sf.copyInto(ctx, db, schemaIdentifier, tableName, sortedColumnNames, tableName, log) + log.Infow("copying data into main table") + loadTableStats, err := sf.copyInto(ctx, db, schemaIdentifier, tableName, sortedColumnNames, tableName) if err != nil { - return tableLoadResp{}, err + return nil, nil, fmt.Errorf("copying data into main table: %w", err) } log.Infow("completed loading") - return tableLoadResp{db: db, stagingTable: tableName}, nil + + resp := &tableLoadResp{ + db: db, + stagingTable: tableName, + } + return loadTableStats, resp, nil } - sqlStatement := fmt.Sprintf(`CREATE TEMPORARY TABLE %[1]s.%[2]q LIKE %[1]s.%[3]q;`, + log.Debugw("creating staging table") + createStagingTableStmt := fmt.Sprintf(`CREATE TEMPORARY TABLE %[1]s.%[2]q LIKE %[1]s.%[3]q;`, schemaIdentifier, stagingTableName, tableName, ) - - log.Debugw("creating temporary table", lf.StagingTableName, stagingTableName) - if _, err = db.ExecContext(ctx, sqlStatement); err != nil { - 
sf.logger.Warnw("failure creating temporary table", - lf.StagingTableName, stagingTableName, - lf.Error, err.Error(), - ) - return tableLoadResp{}, fmt.Errorf("create temporary table: %w", err) + if _, err = db.ExecContext(ctx, createStagingTableStmt); err != nil { + return nil, nil, fmt.Errorf("create staging table: %w", err) } - err = sf.copyInto(ctx, db, schemaIdentifier, tableName, sortedColumnNames, stagingTableName, log) + log.Infow("loading data into staging table") + _, err = sf.copyInto(ctx, db, schemaIdentifier, tableName, sortedColumnNames, stagingTableName) if err != nil { - return tableLoadResp{}, err + return nil, nil, fmt.Errorf("loading data into staging table: %w", err) } - duplicates, err := sf.sampleDuplicateMessages( - ctx, - db, - tableName, - stagingTableName, - ) + duplicates, err := sf.sampleDuplicateMessages(ctx, db, tableName, stagingTableName) if err != nil { log.Warnw("failed to sample duplicate rows", lf.Error, err.Error()) - } else if len(duplicates) > 0 { + } + if len(duplicates) > 0 { uploadID, _ := whutils.UploadIDFromCtx(ctx) formattedDuplicateMessages := lo.Map(duplicates, func(item duplicateMessage, index int) string { @@ -418,19 +429,38 @@ func (sf *Snowflake) loadTable(ctx context.Context, tableName string, tableSchem log.Infow("sample duplicate rows", lf.UploadJobID, uploadID, lf.SampleDuplicateMessages, formattedDuplicateMessages) } - var ( - primaryKey = "ID" - partitionKey = `"ID"` - keepLatestRecordOnDedup = sf.Uploader.ShouldOnDedupUseNewRecord() - - additionalJoinClause string - inserted int64 - updated int64 + log.Infow("merge data into load table") + loadTableStats, err := sf.mergeIntoLoadTable( + ctx, db, schemaIdentifier, tableName, stagingTableName, + sortedColumnNames, strKeys, ) + if err != nil { + return nil, nil, fmt.Errorf("merge into load table: %w", err) + } + + log.Infow("completed loading") + + resp := &tableLoadResp{ + db: db, + stagingTable: stagingTableName, + } + return loadTableStats, resp, nil +} +func (sf *Snowflake) mergeIntoLoadTable( + ctx context.Context, + db *sqlmw.DB, + schemaIdentifier, + tableName string, + stagingTableName string, + sortedColumnNames string, + strKeys []string, +) (*types.LoadTableStats, error) { + primaryKey := "ID" if column, ok := primaryKeyMap[tableName]; ok { primaryKey = column } + partitionKey := `"ID"` if column, ok := partitionKeyMap[tableName]; ok { partitionKey = column } @@ -438,62 +468,57 @@ func (sf *Snowflake) loadTable(ctx context.Context, tableName string, tableSchem stagingColumnNames := sf.joinColumnsWithFormatting(strKeys, `staging.%q`) columnsWithValues := sf.joinColumnsWithFormatting(strKeys, `original.%[1]q = staging.%[1]q`) + var additionalJoinClause string if tableName == discardsTable { additionalJoinClause = fmt.Sprintf(`AND original.%[1]q = staging.%[1]q AND original.%[2]q = staging.%[2]q`, "TABLE_NAME", "COLUMN_NAME") } updateSet := columnsWithValues - if !keepLatestRecordOnDedup { + if !sf.Uploader.ShouldOnDedupUseNewRecord() { // This is being added in order to get the updates count updateSet = fmt.Sprintf(`original.%[1]q = original.%[1]q`, strKeys[0]) } - sqlStatement = sf.mergeIntoStmt( - schemaIdentifier, - tableName, - stagingTableName, - partitionKey, - primaryKey, - additionalJoinClause, - sortedColumnNames, - stagingColumnNames, + mergeStmt := fmt.Sprintf(`MERGE INTO %[1]s.%[2]q AS original USING ( + SELECT * + FROM + ( + SELECT *, + row_number() OVER ( + PARTITION BY %[4]s + ORDER BY + RECEIVED_AT DESC + ) AS _rudder_staging_row_number + FROM + %[1]s.%[3]q + ) 
AS q + WHERE + _rudder_staging_row_number = 1 + ) AS staging ON ( + original.%[5]q = staging.%[5]q %[6]s + ) + WHEN NOT MATCHED THEN + INSERT (%[7]s) VALUES (%[8]s) + WHEN MATCHED THEN + UPDATE SET %[9]s;`, + schemaIdentifier, tableName, stagingTableName, + partitionKey, primaryKey, additionalJoinClause, + sortedColumnNames, stagingColumnNames, updateSet, ) - log.Infow("deduplication", lf.Query, sqlStatement) - - row := db.QueryRowContext(ctx, sqlStatement) - if row.Err() != nil { - log.Warnw("failure running deduplication", - lf.Query, sqlStatement, - lf.Error, row.Err().Error(), - ) - return tableLoadResp{}, fmt.Errorf("merge into table: %w", row.Err()) - } - - if err = row.Scan(&inserted, &updated); err != nil { - log.Warnw("getting rows affected for dedup", - lf.Query, sqlStatement, - lf.Error, err.Error(), - ) - return tableLoadResp{}, fmt.Errorf("getting rows affected for dedup: %w", err) + var rowsInserted, rowsUpdated int64 + err := db.QueryRowContext(ctx, mergeStmt).Scan( + &rowsInserted, + &rowsUpdated, + ) + if err != nil { + return nil, fmt.Errorf("executing merge command: %w", err) } - sf.stats.NewTaggedStat("dedup_rows", stats.CountType, stats.Tags{ - "sourceID": sf.Warehouse.Source.ID, - "sourceType": sf.Warehouse.Source.SourceDefinition.Name, - "sourceCategory": sf.Warehouse.Source.SourceDefinition.Category, - "destID": sf.Warehouse.Destination.ID, - "destType": sf.Warehouse.Destination.DestinationDefinition.Name, - "workspaceId": sf.Warehouse.WorkspaceID, - "tableName": tableName, - }).Count(int(updated)) - - log.Infow("completed loading") - - return tableLoadResp{ - db: db, - stagingTable: stagingTableName, + return &types.LoadTableStats{ + RowsInserted: rowsInserted, + RowsUpdated: rowsUpdated, }, nil } @@ -578,19 +603,22 @@ func (sf *Snowflake) sampleDuplicateMessages( func (sf *Snowflake) copyInto( ctx context.Context, db *sqlmw.DB, - schemaIdentifier, - tableName, - sortedColumnNames, + schemaIdentifier string, + tableName string, + sortedColumnNames string, copyTargetTable string, - log logger.Logger, -) error { +) (*types.LoadTableStats, error) { csvObjectLocation, err := sf.Uploader.GetSampleLoadFileLocation(ctx, tableName) if err != nil { - return fmt.Errorf("getting sample load file location: %w", err) + return nil, fmt.Errorf("getting sample load file location: %w", err) } - loadFolder := whutils.GetObjectFolder(sf.ObjectStorage, csvObjectLocation) - sqlStatement := fmt.Sprintf( + loadFolder := whutils.GetObjectFolder( + sf.ObjectStorage, + csvObjectLocation, + ) + + copyStmt := fmt.Sprintf( `COPY INTO %s.%q(%v) FROM @@ -604,62 +632,58 @@ func (sf *Snowflake) copyInto( sf.authString(), ) - sanitisedQuery, regexErr := misc.ReplaceMultiRegex(sqlStatement, map[string]string{ - "AWS_KEY_ID='[^']*'": "AWS_KEY_ID='***'", - "AWS_SECRET_KEY='[^']*'": "AWS_SECRET_KEY='***'", - "AWS_TOKEN='[^']*'": "AWS_TOKEN='***'", - }) - if regexErr != nil { - sanitisedQuery = "" + rows, err := db.QueryContext(ctx, copyStmt) + if err != nil { + return nil, fmt.Errorf("copy into table: %w", err) } - log.Infow("copy command", lf.Query, sanitisedQuery) + defer func() { + _ = rows.Close() + }() - if _, err := db.ExecContext(ctx, sqlStatement); err != nil { - log.Warnw("failure running COPY command", - lf.Query, sanitisedQuery, - lf.Error, err.Error(), - ) - return fmt.Errorf("copy into table: %w", err) + columns, err := rows.Columns() + if err != nil { + return nil, fmt.Errorf("getting columns: %w", err) } - return nil -} + _, index, found := lo.FindIndexOf(columns, func(item string) bool 
{ + return strings.ToLower(item) == "rows_loaded" + }) + if !found { + sf.logger.Warnw("rows_loaded column not found in copy command result", "columns", columns) + return &types.LoadTableStats{}, nil + } -func (sf *Snowflake) mergeIntoStmt( - schemaIdentifier, tableName, stagingTableName, - partitionKey, - primaryKey, additionalJoinClause, - sortedColumnNames, stagingColumnNames, - updateSet string, -) string { - return fmt.Sprintf(`MERGE INTO %[1]s.%[2]q AS original USING ( - SELECT * - FROM - ( - SELECT *, - row_number() OVER ( - PARTITION BY %[4]s - ORDER BY - RECEIVED_AT DESC - ) AS _rudder_staging_row_number - FROM - %[1]s.%[3]q - ) AS q - WHERE - _rudder_staging_row_number = 1 - ) AS staging ON ( - original.%[5]q = staging.%[5]q %[6]s - ) - WHEN NOT MATCHED THEN - INSERT (%[7]s) VALUES (%[8]s) - WHEN MATCHED THEN - UPDATE SET %[9]s;`, - schemaIdentifier, tableName, stagingTableName, - partitionKey, - primaryKey, additionalJoinClause, - sortedColumnNames, stagingColumnNames, - updateSet, - ) + var rowsInserted int64 + for rows.Next() { + resultSet := make([]any, len(columns)) + resultSetPtrs := make([]any, len(columns)) + for i := 0; i < len(columns); i++ { + resultSetPtrs[i] = &resultSet[i] + } + + if err := rows.Scan(resultSetPtrs...); err != nil { + return nil, fmt.Errorf("scanning row: %w", err) + } + + countString, ok := resultSet[index].(string) + if !ok { + return nil, fmt.Errorf("count not a string") + } + count, err := strconv.Atoi(countString) + if err != nil { + return nil, fmt.Errorf("converting rows loaded: %w", err) + } + + rowsInserted += int64(count) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterating over rows: %w", err) + } + + loadTableStats := &types.LoadTableStats{ + RowsInserted: rowsInserted, + } + return loadTableStats, nil } func (sf *Snowflake) LoadIdentityMergeRulesTable(ctx context.Context) error { @@ -736,11 +760,11 @@ func (sf *Snowflake) LoadIdentityMappingsTable(ctx context.Context) error { ) log = log.With(lf.StagingTableName, stagingTableName) - log.Infow("Creating temporary table", lf.Query, sqlStatement) + log.Infow("Creating staging table", lf.Query, sqlStatement) _, err = db.ExecContext(ctx, sqlStatement) if err != nil { - log.Errorw("Error creating temporary table", + log.Errorw("Error creating staging table", lf.Query, sqlStatement, lf.Error, err.Error(), ) @@ -852,7 +876,7 @@ func (sf *Snowflake) LoadUserTables(ctx context.Context) map[string]error { ) log.Infow("started loading for identifies and users tables") - resp, err := sf.loadTable(ctx, identifiesTable, identifiesSchema, true) + _, resp, err := sf.loadTable(ctx, identifiesTable, identifiesSchema, true) if err != nil { return map[string]error{ identifiesTable: fmt.Errorf("loading table %s: %w", identifiesTable, err), @@ -882,9 +906,7 @@ func (sf *Snowflake) LoadUserTables(ctx context.Context) map[string]error { strKeys := sf.getSortedColumnsFromTableSchema(identifiesSchema) sortedColumnNames := sf.joinColumnsWithFormatting(strKeys, "%q") - err = sf.copyInto( - ctx, resp.db, schemaIdentifier, identifiesTable, sortedColumnNames, tmpIdentifiesStagingTable, log, - ) + _, err = sf.copyInto(ctx, resp.db, schemaIdentifier, identifiesTable, sortedColumnNames, tmpIdentifiesStagingTable) if err != nil { return map[string]error{ identifiesTable: fmt.Errorf("loading identifies temp table %s: %w", identifiesTable, err), @@ -1418,25 +1440,14 @@ func (sf *Snowflake) Cleanup(context.Context) { } } -func (sf *Snowflake) LoadTable(ctx context.Context, tableName string) error { - 
_, err := sf.loadTable(ctx, tableName, sf.Uploader.GetTableSchemaInUpload(tableName), false) - return err -} - -func (sf *Snowflake) GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) { - var ( - total int64 - err error - sqlStatement string - ) - sqlStatement = fmt.Sprintf(` - SELECT count(*) FROM %[1]s.%[2]q; - `, - sf.schemaIdentifier(), +func (sf *Snowflake) LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error) { + loadTableStat, _, err := sf.loadTable( + ctx, tableName, + sf.Uploader.GetTableSchemaInUpload(tableName), + false, ) - err = sf.DB.QueryRowContext(ctx, sqlStatement).Scan(&total) - return total, err + return loadTableStat, err } func (sf *Snowflake) Connect(ctx context.Context, warehouse model.Warehouse) (client.Client, error) { diff --git a/warehouse/integrations/snowflake/snowflake_test.go b/warehouse/integrations/snowflake/snowflake_test.go index 9065b876a1..2e9c614b3b 100644 --- a/warehouse/integrations/snowflake/snowflake_test.go +++ b/warehouse/integrations/snowflake/snowflake_test.go @@ -12,6 +12,12 @@ import ( "testing" "time" + "github.com/samber/lo" + "golang.org/x/exp/slices" + + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/golang/mock/gomock" sfdb "github.com/snowflakedb/gosnowflake" "github.com/stretchr/testify/require" @@ -61,7 +67,7 @@ func getSnowflakeTestCredentials(key string) (*testCredentials, error) { var credentials testCredentials err := json.Unmarshal([]byte(cred), &credentials) if err != nil { - return nil, fmt.Errorf("failed to snowflake redshift test credentials: %w", err) + return nil, fmt.Errorf("failed to unmarshal snowflake test credentials: %w", err) } return &credentials, nil } @@ -511,6 +517,448 @@ func TestIntegration(t *testing.T) { } testhelper.VerifyConfigurationTest(t, dest) }) + + t.Run("Load Table", func(t *testing.T) { + const ( + sourceID = "test_source_id" + destinationID = "test_destination_id" + workspaceID = "test_workspace_id" + ) + + namespace := testhelper.RandSchema(destType) + + ctx := context.Background() + + urlConfig := sfdb.Config{ + Account: credentials.Account, + User: credentials.User, + Role: credentials.Role, + Password: credentials.Password, + Database: credentials.Database, + Warehouse: credentials.Warehouse, + } + + dsn, err := sfdb.DSN(&urlConfig) + require.NoError(t, err) + + db := getSnowflakeDB(t, dsn) + require.NoError(t, db.Ping()) + + t.Cleanup(func() { + require.Eventually(t, func() bool { + if _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA %q CASCADE;`, namespace)); err != nil { + t.Logf("error deleting schema: %v", err) + return false + } + return true + }, + time.Minute, + time.Second, + ) + }) + + schemaInUpload := model.TableSchema{ + "TEST_BOOL": "boolean", + "TEST_DATETIME": "datetime", + "TEST_FLOAT": "float", + "TEST_INT": "int", + "TEST_STRING": "string", + "ID": "string", + "RECEIVED_AT": "datetime", + } + schemaInWarehouse := model.TableSchema{ + "TEST_BOOL": "boolean", + "TEST_DATETIME": "datetime", + "TEST_FLOAT": "float", + "TEST_INT": "int", + "TEST_STRING": "string", + "ID": "string", + "RECEIVED_AT": "datetime", + "EXTRA_TEST_BOOL": "boolean", + "EXTRA_TEST_DATETIME": "datetime", + "EXTRA_TEST_FLOAT": "float", + "EXTRA_TEST_INT": "int", + "EXTRA_TEST_STRING": "string", + } + + warehouse := model.Warehouse{ + Source: backendconfig.SourceT{ + ID: sourceID, + }, + Destination: backendconfig.DestinationT{ + ID: destinationID, + DestinationDefinition: 
backendconfig.DestinationDefinitionT{ + Name: destType, + }, + Config: map[string]any{ + "account": credentials.Account, + "database": credentials.Database, + "warehouse": credentials.Warehouse, + "user": credentials.User, + "password": credentials.Password, + "cloudProvider": "AWS", + "bucketName": credentials.BucketName, + "storageIntegration": "", + "accessKeyID": credentials.AccessKeyID, + "accessKey": credentials.AccessKey, + "namespace": namespace, + }, + }, + WorkspaceID: workspaceID, + Namespace: namespace, + } + + fm, err := filemanager.New(&filemanager.Settings{ + Provider: whutils.S3, + Config: map[string]any{ + "bucketName": credentials.BucketName, + "accessKeyID": credentials.AccessKeyID, + "accessKey": credentials.AccessKey, + "bucketProvider": whutils.S3, + }, + }) + require.NoError(t, err) + + t.Run("schema does not exists", func(t *testing.T) { + tableName := whutils.ToProviderCase(whutils.SNOWFLAKE, "schema_not_exists_test_table") + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, whutils.LoadFileTypeCsv, false, false) + + sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + require.NoError(t, err) + err = sf.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("table does not exists", func(t *testing.T) { + tableName := whutils.ToProviderCase(whutils.SNOWFLAKE, "table_not_exists_test_table") + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, whutils.LoadFileTypeCsv, false, false) + + sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + require.NoError(t, err) + err = sf.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = sf.CreateSchema(ctx) + require.NoError(t, err) + + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("merge", func(t *testing.T) { + tableName := whutils.ToProviderCase(whutils.SNOWFLAKE, "merge_test_table") + + t.Run("without dedup", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, whutils.LoadFileTypeCsv, false, false) + + c := config.New() + c.Set("Warehouse.snowflake.debugDuplicateWorkspaceIDs", []string{workspaceID}) + c.Set("Warehouse.snowflake.debugDuplicateIntervalInDays", 1000) + c.Set("Warehouse.snowflake.debugDuplicateTables", []string{whutils.ToProviderCase( + whutils.SNOWFLAKE, + tableName, + )}) + + sf, err := snowflake.New(c, logger.NOP, stats.Default) + require.NoError(t, err) + err = sf.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = sf.CreateSchema(ctx) + require.NoError(t, err) + + err = sf.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, 
loadTableStat.RowsUpdated, int64(0)) + + loadTableStat, err = sf.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, sf.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("with dedup use new record", func(t *testing.T) { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/dedup.csv.gz", tableName) + + loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, whutils.LoadFileTypeCsv, false, true) + + sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + require.NoError(t, err) + err = sf.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = sf.CreateSchema(ctx) + require.NoError(t, err) + + err = sf.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(14)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, db, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DedupTestRecords()) + }) + }) + t.Run("append", func(t *testing.T) { + tableName := whutils.ToProviderCase(whutils.SNOWFLAKE, "append_test_table") + + run := func() { + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName) + + loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, whutils.LoadFileTypeCsv, true, false) + + c := config.New() + c.Set("Warehouse.snowflake.loadTableStrategy", "APPEND") + + sf, err := snowflake.New(c, logger.NOP, stats.Default) + require.NoError(t, err) + err = sf.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = sf.CreateSchema(ctx) + require.NoError(t, err) + + err = sf.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + t.Run("loading once should copy everything", func(t *testing.T) { + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + }) + t.Run("loading twice should not copy anything", func(t *testing.T) { + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(0)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + }) + } + + run() + run() + + records := testhelper.RetrieveRecordsFromWarehouse(t, db, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.AppendTestRecords()) + }) + t.Run("load file does not exists", func(t *testing.T) { + tableName := whutils.ToProviderCase(whutils.SNOWFLAKE, 
"load_file_not_exists_test_table") + + loadFiles := []whutils.LoadFile{{ + Location: "https://bucket.s3.amazonaws.com/rudder-warehouse-load-objects/load_file_not_exists_test_table/test_source_id/0ef75cb0-3fd0-4408-98b9-2bea9e476916-load_file_not_exists_test_table/load.csv.gz", + }} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, whutils.LoadFileTypeCsv, false, false) + + sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + require.NoError(t, err) + err = sf.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = sf.CreateSchema(ctx) + require.NoError(t, err) + + err = sf.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("mismatch in number of columns", func(t *testing.T) { + tableName := whutils.ToProviderCase(whutils.SNOWFLAKE, "mismatch_columns_test_table") + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-columns.csv.gz", tableName) + + loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, whutils.LoadFileTypeCsv, false, false) + + sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + require.NoError(t, err) + err = sf.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = sf.CreateSchema(ctx) + require.NoError(t, err) + + err = sf.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(14)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, sf.DB.DB, + fmt.Sprintf(` + SELECT + id, + received_at, + test_bool, + test_datetime, + test_float, + test_int, + test_string + FROM + %q.%q + ORDER BY + id; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.SampleTestRecords()) + }) + t.Run("mismatch in schema", func(t *testing.T) { + tableName := whutils.ToProviderCase(whutils.SNOWFLAKE, "mismatch_schema_test_table") + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/mismatch-schema.csv.gz", tableName) + + loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, whutils.LoadFileTypeCsv, false, false) + + sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + require.NoError(t, err) + err = sf.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = sf.CreateSchema(ctx) + require.NoError(t, err) + + err = sf.CreateTable(ctx, tableName, schemaInWarehouse) + require.NoError(t, err) + + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.Error(t, err) + require.Nil(t, loadTableStat) + }) + t.Run("discards", func(t *testing.T) { + tableName := whutils.ToProviderCase(whutils.SNOWFLAKE, whutils.DiscardsTable) + + uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/discards.csv.gz", tableName) + + discardsSchema := lo.MapKeys(whutils.DiscardsSchema, func(_, key string) string { + return whutils.ToProviderCase(whutils.SNOWFLAKE, key) + }) + + loadFiles := []whutils.LoadFile{{Location: uploadOutput.Location}} + mockUploader := newMockUploader(t, loadFiles, tableName, discardsSchema, discardsSchema, 
whutils.LoadFileTypeCsv, false, false) + + sf, err := snowflake.New(config.Default, logger.NOP, stats.Default) + require.NoError(t, err) + err = sf.Setup(ctx, warehouse, mockUploader) + require.NoError(t, err) + + err = sf.CreateSchema(ctx) + require.NoError(t, err) + + err = sf.CreateTable(ctx, tableName, discardsSchema) + require.NoError(t, err) + + loadTableStat, err := sf.LoadTable(ctx, tableName) + require.NoError(t, err) + require.Equal(t, loadTableStat.RowsInserted, int64(6)) + require.Equal(t, loadTableStat.RowsUpdated, int64(0)) + + records := testhelper.RetrieveRecordsFromWarehouse(t, sf.DB.DB, + fmt.Sprintf(` + SELECT + COLUMN_NAME, + COLUMN_VALUE, + RECEIVED_AT, + ROW_ID, + TABLE_NAME, + UUID_TS + FROM + %q.%q + ORDER BY ROW_ID ASC; + `, + namespace, + tableName, + ), + ) + require.Equal(t, records, testhelper.DiscardTestRecords()) + }) + }) } func TestSnowflake_ShouldAppend(t *testing.T) { @@ -569,6 +1017,36 @@ func TestSnowflake_ShouldAppend(t *testing.T) { } } +func newMockUploader( + t testing.TB, + loadFiles []whutils.LoadFile, + tableName string, + schemaInUpload model.TableSchema, + schemaInWarehouse model.TableSchema, + loadFileType string, + canAppend bool, + dedupUseNewRecord bool, +) whutils.Uploader { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockUploader := mockuploader.NewMockUploader(ctrl) + mockUploader.EXPECT().UseRudderStorage().Return(false).AnyTimes() + mockUploader.EXPECT().CanAppend().Return(canAppend).AnyTimes() + mockUploader.EXPECT().ShouldOnDedupUseNewRecord().Return(dedupUseNewRecord).AnyTimes() + mockUploader.EXPECT().GetLoadFilesMetadata(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, options whutils.GetLoadFilesOptions) []whutils.LoadFile { + return slices.Clone(loadFiles) + }, + ).AnyTimes() + mockUploader.EXPECT().GetSampleLoadFileLocation(gomock.Any(), gomock.Any()).Return(loadFiles[0].Location, nil).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInUpload(tableName).Return(schemaInUpload).AnyTimes() + mockUploader.EXPECT().GetTableSchemaInWarehouse(tableName).Return(schemaInWarehouse).AnyTimes() + mockUploader.EXPECT().GetLoadFileType().Return(loadFileType).AnyTimes() + + return mockUploader +} + func getSnowflakeDB(t testing.TB, dsn string) *sql.DB { t.Helper() db, err := sql.Open("snowflake", dsn) diff --git a/warehouse/integrations/testdata/dedup.csv.gz b/warehouse/integrations/testdata/dedup.csv.gz new file mode 100644 index 0000000000..19788f2c70 Binary files /dev/null and b/warehouse/integrations/testdata/dedup.csv.gz differ diff --git a/warehouse/integrations/testdata/dedup.json.gz b/warehouse/integrations/testdata/dedup.json.gz new file mode 100644 index 0000000000..f07ef9600d Binary files /dev/null and b/warehouse/integrations/testdata/dedup.json.gz differ diff --git a/warehouse/integrations/postgres/testdata/discards.csv.gz b/warehouse/integrations/testdata/discards.csv.gz similarity index 84% rename from warehouse/integrations/postgres/testdata/discards.csv.gz rename to warehouse/integrations/testdata/discards.csv.gz index 1a0f93bef7..9597cd7e9f 100644 Binary files a/warehouse/integrations/postgres/testdata/discards.csv.gz and b/warehouse/integrations/testdata/discards.csv.gz differ diff --git a/warehouse/integrations/testdata/discards.json.gz b/warehouse/integrations/testdata/discards.json.gz new file mode 100644 index 0000000000..811b3035bc Binary files /dev/null and b/warehouse/integrations/testdata/discards.json.gz differ diff --git a/warehouse/integrations/testdata/load.csv.gz 
b/warehouse/integrations/testdata/load.csv.gz new file mode 100644 index 0000000000..7a6a727257 Binary files /dev/null and b/warehouse/integrations/testdata/load.csv.gz differ diff --git a/warehouse/integrations/testdata/load.json.gz b/warehouse/integrations/testdata/load.json.gz new file mode 100644 index 0000000000..9ad30c074b Binary files /dev/null and b/warehouse/integrations/testdata/load.json.gz differ diff --git a/warehouse/integrations/testdata/load.parquet b/warehouse/integrations/testdata/load.parquet new file mode 100644 index 0000000000..b6f7dea543 Binary files /dev/null and b/warehouse/integrations/testdata/load.parquet differ diff --git a/warehouse/integrations/testdata/mismatch-columns.csv.gz b/warehouse/integrations/testdata/mismatch-columns.csv.gz new file mode 100644 index 0000000000..9ab53063ae Binary files /dev/null and b/warehouse/integrations/testdata/mismatch-columns.csv.gz differ diff --git a/warehouse/integrations/testdata/mismatch-columns.json.gz b/warehouse/integrations/testdata/mismatch-columns.json.gz new file mode 100644 index 0000000000..b314404ddf Binary files /dev/null and b/warehouse/integrations/testdata/mismatch-columns.json.gz differ diff --git a/warehouse/integrations/testdata/mismatch-schema.csv.gz b/warehouse/integrations/testdata/mismatch-schema.csv.gz new file mode 100644 index 0000000000..d7291d6759 Binary files /dev/null and b/warehouse/integrations/testdata/mismatch-schema.csv.gz differ diff --git a/warehouse/integrations/testdata/mismatch-schema.json.gz b/warehouse/integrations/testdata/mismatch-schema.json.gz new file mode 100644 index 0000000000..2265b8fb28 Binary files /dev/null and b/warehouse/integrations/testdata/mismatch-schema.json.gz differ diff --git a/warehouse/integrations/testhelper/setup.go b/warehouse/integrations/testhelper/setup.go index 71dca72506..91523e0b7f 100644 --- a/warehouse/integrations/testhelper/setup.go +++ b/warehouse/integrations/testhelper/setup.go @@ -1,13 +1,21 @@ package testhelper import ( + "context" "database/sql" "fmt" + "os" "strconv" "strings" "testing" "time" + "github.com/google/uuid" + "github.com/samber/lo" + "github.com/spf13/cast" + + "github.com/rudderlabs/rudder-go-kit/filemanager" + "github.com/rudderlabs/rudder-go-kit/testhelper/rand" "github.com/rudderlabs/rudder-server/utils/timeutil" @@ -242,3 +250,214 @@ func EnhanceWithDefaultEnvs(t testing.TB) { t.Setenv("LOG_LEVEL", "DEBUG") } } + +func UploadLoadFile( + t testing.TB, + fm filemanager.FileManager, + fileName string, + tableName string, +) filemanager.UploadedFile { + t.Helper() + + f, err := os.Open(fileName) + require.NoError(t, err) + defer func() { _ = f.Close() }() + + loadObjectFolder := "rudder-warehouse-load-objects" + sourceID := "test_source_id" + + uploadOutput, err := fm.Upload( + context.Background(), f, loadObjectFolder, + tableName, sourceID, uuid.New().String()+"-"+tableName, + ) + require.NoError(t, err) + + return uploadOutput +} + +// RetrieveRecordsFromWarehouse retrieves records from the warehouse based on the given query. +// It returns a slice of slices, where each inner slice represents a record's values. 
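For orientation, the "Load Table" subtests above all follow the same shape, built from these two helpers plus the LoadTableStats now returned by LoadTable. A condensed sketch of that flow, with every identifier taken from the Redshift subtests earlier in this diff (this is not an additional standalone test, just the pattern distilled):

	uploadOutput := testhelper.UploadLoadFile(t, fm, "../testdata/load.csv.gz", tableName)
	loadFiles := []warehouseutils.LoadFile{{Location: uploadOutput.Location}}
	mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse, warehouseutils.LoadFileTypeCsv)

	rs := redshift.New(config.Default, logger.NOP, stats.Default)
	require.NoError(t, rs.Setup(ctx, warehouse, mockUploader))
	require.NoError(t, rs.CreateSchema(ctx))
	require.NoError(t, rs.CreateTable(ctx, tableName, schemaInWarehouse))

	// The first load inserts every row from the load file; on the merge path a
	// second call reports the same rows as updates instead of inserts.
	loadTableStat, err := rs.LoadTable(ctx, tableName)
	require.NoError(t, err)
	require.Equal(t, loadTableStat.RowsInserted, int64(14))
	require.Equal(t, loadTableStat.RowsUpdated, int64(0))

	records := testhelper.RetrieveRecordsFromWarehouse(t, rs.DB.DB,
		fmt.Sprintf(`SELECT id, received_at, test_bool, test_datetime, test_float, test_int, test_string FROM %q.%q ORDER BY id;`, namespace, tableName),
	)
	require.Equal(t, records, testhelper.SampleTestRecords())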
+func RetrieveRecordsFromWarehouse( + t testing.TB, + db *sql.DB, + query string, +) [][]string { + t.Helper() + + rows, err := db.QueryContext(context.Background(), query) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + _ = rows.Err() + + columns, err := rows.Columns() + require.NoError(t, err) + + var records [][]string + for rows.Next() { + resultSet := make([]any, len(columns)) + resultSetPtrs := make([]any, len(columns)) + for i := 0; i < len(columns); i++ { + resultSetPtrs[i] = &resultSet[i] + } + + err = rows.Scan(resultSetPtrs...) + require.NoError(t, err) + + records = append(records, lo.Map(resultSet, func(item any, index int) string { + switch item := item.(type) { + case time.Time: + return item.Format(time.RFC3339) + default: + return cast.ToString(item) + } + })) + } + return records +} + +// SampleTestRecords returns a set of records for testing default loading scenarios. +// It uses testdata/load.* as the source of data. +func SampleTestRecords() [][]string { + return [][]string{ + {"6734e5db-f918-4efe-1421-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "125", ""}, + {"6734e5db-f918-4efe-2314-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "125.75", "", ""}, + {"6734e5db-f918-4efe-2352-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + {"6734e5db-f918-4efe-2414-872f66e235c5", "2022-12-15T06:53:49Z", "false", "2022-12-15T06:53:49Z", "126.75", "126", "hello-world"}, + {"6734e5db-f918-4efe-3555-872f66e235c5", "2022-12-15T06:53:49Z", "false", "", "", "", ""}, + {"6734e5db-f918-4efe-5152-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "hello-world"}, + {"6734e5db-f918-4efe-5323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-1212-872f66e235c5", "2022-12-15T06:53:49Z", "true", "2022-12-15T06:53:49Z", "125.75", "125", "hello-world"}, + {"7274e5db-f918-4efe-1454-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "125", ""}, + {"7274e5db-f918-4efe-1511-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-2323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "125.75", "", ""}, + {"7274e5db-f918-4efe-4524-872f66e235c5", "2022-12-15T06:53:49Z", "true", "", "", "", ""}, + {"7274e5db-f918-4efe-5151-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "hello-world"}, + {"7274e5db-f918-4efe-5322-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + } +} + +// AppendTestRecords returns a set of records for testing append scenarios. +// It uses testdata/load.* twice as the source of data. 
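A note on the shape of these fixtures: RetrieveRecordsFromWarehouse deliberately flattens every cell to a string (time.Time values via RFC3339, everything else via cast.ToString), since different warehouse drivers surface numerics, booleans and timestamps as different Go types. Normalizing to [][]string is what lets a single set of expected records, such as SampleTestRecords above and the functions below, be shared verbatim across the per-warehouse Load Table suites. For illustration, all of these land on the same fixture cell:

	// int64(125)      -> "125"
	// "125" (string)  -> "125"
	// a TIMESTAMP     -> "2022-12-15T06:53:49Z"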
+func AppendTestRecords() [][]string { + return [][]string{ + {"6734e5db-f918-4efe-1421-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "125", ""}, + {"6734e5db-f918-4efe-1421-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "125", ""}, + {"6734e5db-f918-4efe-2314-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "125.75", "", ""}, + {"6734e5db-f918-4efe-2314-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "125.75", "", ""}, + {"6734e5db-f918-4efe-2352-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + {"6734e5db-f918-4efe-2352-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + {"6734e5db-f918-4efe-2414-872f66e235c5", "2022-12-15T06:53:49Z", "false", "2022-12-15T06:53:49Z", "126.75", "126", "hello-world"}, + {"6734e5db-f918-4efe-2414-872f66e235c5", "2022-12-15T06:53:49Z", "false", "2022-12-15T06:53:49Z", "126.75", "126", "hello-world"}, + {"6734e5db-f918-4efe-3555-872f66e235c5", "2022-12-15T06:53:49Z", "false", "", "", "", ""}, + {"6734e5db-f918-4efe-3555-872f66e235c5", "2022-12-15T06:53:49Z", "false", "", "", "", ""}, + {"6734e5db-f918-4efe-5152-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "hello-world"}, + {"6734e5db-f918-4efe-5152-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "hello-world"}, + {"6734e5db-f918-4efe-5323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"6734e5db-f918-4efe-5323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-1212-872f66e235c5", "2022-12-15T06:53:49Z", "true", "2022-12-15T06:53:49Z", "125.75", "125", "hello-world"}, + {"7274e5db-f918-4efe-1212-872f66e235c5", "2022-12-15T06:53:49Z", "true", "2022-12-15T06:53:49Z", "125.75", "125", "hello-world"}, + {"7274e5db-f918-4efe-1454-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "125", ""}, + {"7274e5db-f918-4efe-1454-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "125", ""}, + {"7274e5db-f918-4efe-1511-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-1511-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-2323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "125.75", "", ""}, + {"7274e5db-f918-4efe-2323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "125.75", "", ""}, + {"7274e5db-f918-4efe-4524-872f66e235c5", "2022-12-15T06:53:49Z", "true", "", "", "", ""}, + {"7274e5db-f918-4efe-4524-872f66e235c5", "2022-12-15T06:53:49Z", "true", "", "", "", ""}, + {"7274e5db-f918-4efe-5151-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "hello-world"}, + {"7274e5db-f918-4efe-5151-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "hello-world"}, + {"7274e5db-f918-4efe-5322-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + {"7274e5db-f918-4efe-5322-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + } +} + +// DiscardTestRecords returns a set of records for testing rudder discards. +// It uses testdata/discards.* as the source of data. 
+func DiscardTestRecords() [][]string { + return [][]string{ + {"context_screen_density", "125.75", "2022-12-15T06:53:49Z", "1", "test_table", "2022-12-15T06:53:49Z"}, + {"context_screen_density", "125", "2022-12-15T06:53:49Z", "2", "test_table", "2022-12-15T06:53:49Z"}, + {"context_screen_density", "true", "2022-12-15T06:53:49Z", "3", "test_table", "2022-12-15T06:53:49Z"}, + {"context_screen_density", "7274e5db-f918-4efe-1212-872f66e235c5", "2022-12-15T06:53:49Z", "4", "test_table", "2022-12-15T06:53:49Z"}, + {"context_screen_density", "hello-world", "2022-12-15T06:53:49Z", "5", "test_table", "2022-12-15T06:53:49Z"}, + {"context_screen_density", "2022-12-15T06:53:49.640Z", "2022-12-15T06:53:49Z", "6", "test_table", "2022-12-15T06:53:49Z"}, + } +} + +// DedupTestRecords returns a set of records for testing deduplication scenarios. +// It uses testdata/dedup.* as the source of data. +func DedupTestRecords() [][]string { + return [][]string{ + {"6734e5db-f918-4efe-1421-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "521", ""}, + {"6734e5db-f918-4efe-2314-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "75.125", "", ""}, + {"6734e5db-f918-4efe-2352-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + {"6734e5db-f918-4efe-2414-872f66e235c5", "2022-12-15T06:53:49Z", "true", "2022-12-15T06:53:49Z", "75.125", "521", "world-hello"}, + {"6734e5db-f918-4efe-3555-872f66e235c5", "2022-12-15T06:53:49Z", "true", "", "", "", ""}, + {"6734e5db-f918-4efe-5152-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "world-hello"}, + {"6734e5db-f918-4efe-5323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-1212-872f66e235c5", "2022-12-15T06:53:49Z", "false", "2022-12-15T06:53:49Z", "75.125", "521", "world-hello"}, + {"7274e5db-f918-4efe-1454-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "521", ""}, + {"7274e5db-f918-4efe-1511-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-2323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "75.125", "", ""}, + {"7274e5db-f918-4efe-4524-872f66e235c5", "2022-12-15T06:53:49Z", "false", "", "", "", ""}, + {"7274e5db-f918-4efe-5151-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "world-hello"}, + {"7274e5db-f918-4efe-5322-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + } +} + +// DedupTwiceTestRecords returns a set of records for testing deduplication scenarios. +// It uses testdata/dedup.* as the source of data. 
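Worth noting when reading the dedup fixtures: dedup.csv.gz reuses the IDs from load.csv.gz but with different values (521 instead of 125, 75.125 instead of 125.75, world-hello instead of hello-world), so when a dedup subtest loads it on top of the previously loaded data and then finds DedupTestRecords in the warehouse, it has verified that the existing rows were overwritten rather than duplicated. DedupTwiceTestRecords below captures the dedup-window case, where rows older than the configured window are not considered for deduplication and the same file ends up appended instead.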
+func DedupTwiceTestRecords() [][]string { + return [][]string{ + {"6734e5db-f918-4efe-1421-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "521", ""}, + {"6734e5db-f918-4efe-1421-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "521", ""}, + {"6734e5db-f918-4efe-2314-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "75.125", "", ""}, + {"6734e5db-f918-4efe-2314-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "75.125", "", ""}, + {"6734e5db-f918-4efe-2352-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + {"6734e5db-f918-4efe-2352-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + {"6734e5db-f918-4efe-2414-872f66e235c5", "2022-12-15T06:53:49Z", "true", "2022-12-15T06:53:49Z", "75.125", "521", "world-hello"}, + {"6734e5db-f918-4efe-2414-872f66e235c5", "2022-12-15T06:53:49Z", "true", "2022-12-15T06:53:49Z", "75.125", "521", "world-hello"}, + {"6734e5db-f918-4efe-3555-872f66e235c5", "2022-12-15T06:53:49Z", "true", "", "", "", ""}, + {"6734e5db-f918-4efe-3555-872f66e235c5", "2022-12-15T06:53:49Z", "true", "", "", "", ""}, + {"6734e5db-f918-4efe-5152-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "world-hello"}, + {"6734e5db-f918-4efe-5152-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "world-hello"}, + {"6734e5db-f918-4efe-5323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"6734e5db-f918-4efe-5323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-1212-872f66e235c5", "2022-12-15T06:53:49Z", "false", "2022-12-15T06:53:49Z", "75.125", "521", "world-hello"}, + {"7274e5db-f918-4efe-1212-872f66e235c5", "2022-12-15T06:53:49Z", "false", "2022-12-15T06:53:49Z", "75.125", "521", "world-hello"}, + {"7274e5db-f918-4efe-1454-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "521", ""}, + {"7274e5db-f918-4efe-1454-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "521", ""}, + {"7274e5db-f918-4efe-1511-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-1511-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-2323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "75.125", "", ""}, + {"7274e5db-f918-4efe-2323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "75.125", "", ""}, + {"7274e5db-f918-4efe-4524-872f66e235c5", "2022-12-15T06:53:49Z", "false", "", "", "", ""}, + {"7274e5db-f918-4efe-4524-872f66e235c5", "2022-12-15T06:53:49Z", "false", "", "", "", ""}, + {"7274e5db-f918-4efe-5151-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "world-hello"}, + {"7274e5db-f918-4efe-5151-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "world-hello"}, + {"7274e5db-f918-4efe-5322-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + {"7274e5db-f918-4efe-5322-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + } +} + +// MismatchSchemaTestRecords returns a set of records for testing schema mismatch scenarios. +// It uses testdata/mismatch-schema.* as the source of data. 
+func MismatchSchemaTestRecords() [][]string { + return [][]string{ + {"6734e5db-f918-4efe-1421-872f66e235c5", "", "", "", "", "", ""}, + {"6734e5db-f918-4efe-2314-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "125.75", "", ""}, + {"6734e5db-f918-4efe-2352-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + {"6734e5db-f918-4efe-2414-872f66e235c5", "2022-12-15T06:53:49Z", "false", "2022-12-15T06:53:49Z", "126.75", "126", "hello-world"}, + {"6734e5db-f918-4efe-3555-872f66e235c5", "2022-12-15T06:53:49Z", "false", "", "", "", ""}, + {"6734e5db-f918-4efe-5152-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "hello-world"}, + {"6734e5db-f918-4efe-5323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-1212-872f66e235c5", "2022-12-15T06:53:49Z", "true", "2022-12-15T06:53:49Z", "125.75", "125", "hello-world"}, + {"7274e5db-f918-4efe-1454-872f66e235c5", "", "", "", "", "", ""}, + {"7274e5db-f918-4efe-1511-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", ""}, + {"7274e5db-f918-4efe-2323-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "125.75", "", ""}, + {"7274e5db-f918-4efe-4524-872f66e235c5", "2022-12-15T06:53:49Z", "true", "", "", "", ""}, + {"7274e5db-f918-4efe-5151-872f66e235c5", "2022-12-15T06:53:49Z", "", "", "", "", "hello-world"}, + {"7274e5db-f918-4efe-5322-872f66e235c5", "2022-12-15T06:53:49Z", "", "2022-12-15T06:53:49Z", "", "", ""}, + } +} diff --git a/warehouse/integrations/types/types.go b/warehouse/integrations/types/types.go new file mode 100644 index 0000000000..1f994b5b40 --- /dev/null +++ b/warehouse/integrations/types/types.go @@ -0,0 +1,6 @@ +package types + +type LoadTableStats struct { + RowsInserted int64 + RowsUpdated int64 +} diff --git a/warehouse/internal/model/schema.go b/warehouse/internal/model/schema.go index fff0ec0749..4855ab54a3 100644 --- a/warehouse/internal/model/schema.go +++ b/warehouse/internal/model/schema.go @@ -3,7 +3,7 @@ package model import "time" type ( - SchemaType string + SchemaType = string TableSchema map[string]string Schema map[string]TableSchema ) @@ -16,6 +16,7 @@ const ( FloatDataType SchemaType = "float" JSONDataType SchemaType = "json" TextDataType SchemaType = "text" + DateTimeDataType SchemaType = "datetime" ArrayOfBooleanDatatype SchemaType = "array(boolean)" ) diff --git a/warehouse/logfield/logfield.go b/warehouse/logfield/logfield.go index 66e140ff87..921ca64ac3 100644 --- a/warehouse/logfield/logfield.go +++ b/warehouse/logfield/logfield.go @@ -18,6 +18,7 @@ const ( TableName = "tableName" ColumnName = "columnName" ColumnType = "columnType" + ColumnValue = "columnValue" Priority = "priority" Retried = "retried" Attempt = "attempt" diff --git a/warehouse/schema/schema.go b/warehouse/schema/schema.go index a6154e7dd6..e098948d2e 100644 --- a/warehouse/schema/schema.go +++ b/warehouse/schema/schema.go @@ -294,8 +294,8 @@ func consolidateStagingSchemas(consolidatedSchema model.Schema, schemas []model. 
consolidatedSchema[tableName] = model.TableSchema{} } for columnName, columnType := range columnMap { - if model.SchemaType(columnType) == model.TextDataType { - consolidatedSchema[tableName][columnName] = string(model.TextDataType) + if columnType == model.TextDataType { + consolidatedSchema[tableName][columnName] = model.TextDataType continue } @@ -322,8 +322,8 @@ func consolidateWarehouseSchema(consolidatedSchema, warehouseSchema model.Schema } var ( - consolidatedSchemaType = model.SchemaType(consolidatedSchema[tableName][columnName]) - warehouseSchemaType = model.SchemaType(columnType) + consolidatedSchemaType = consolidatedSchema[tableName][columnName] + warehouseSchemaType = columnType ) if consolidatedSchemaType == model.TextDataType && warehouseSchemaType == model.StringDataType { diff --git a/warehouse/slave_worker.go b/warehouse/slave_worker.go index cbb9b55fa9..f8c54423e2 100644 --- a/warehouse/slave_worker.go +++ b/warehouse/slave_worker.go @@ -308,7 +308,7 @@ func (sw *slaveWorker) processStagingFile(ctx context.Context, job payload) ([]u if job.DestinationType == warehouseutils.CLICKHOUSE { switch columnType { - case string(model.BooleanDataType): + case model.BooleanDataType: newColumnVal := 0 if k, ok := columnVal.(bool); ok { @@ -318,7 +318,7 @@ func (sw *slaveWorker) processStagingFile(ctx context.Context, job payload) ([]u } columnVal = newColumnVal - case string(model.ArrayOfBooleanDatatype): + case model.ArrayOfBooleanDatatype: if boolValue, ok := columnVal.([]interface{}); ok { newColumnVal := make([]interface{}, len(boolValue)) @@ -335,7 +335,7 @@ func (sw *slaveWorker) processStagingFile(ctx context.Context, job payload) ([]u } } - if model.SchemaType(columnType) == model.IntDataType || model.SchemaType(columnType) == model.BigIntDataType { + if columnType == model.IntDataType || columnType == model.BigIntDataType { floatVal, ok := columnVal.(float64) if !ok { eventLoader.AddEmptyColumn(columnName) @@ -350,8 +350,8 @@ func (sw *slaveWorker) processStagingFile(ctx context.Context, job payload) ([]u if ok && ((columnType != dataTypeInSchema) || (violatedConstraints.isViolated)) { newColumnVal, convError := handleSchemaChange( - model.SchemaType(dataTypeInSchema), - model.SchemaType(columnType), + dataTypeInSchema, + columnType, columnVal, ) diff --git a/warehouse/slave_worker_test.go b/warehouse/slave_worker_test.go index d347409c61..42a1803e51 100644 --- a/warehouse/slave_worker_test.go +++ b/warehouse/slave_worker_test.go @@ -970,8 +970,8 @@ func TestHandleSchemaChange(t *testing.T) { t.Parallel() newColumnVal, convError := handleSchemaChange( - model.SchemaType(tc.existingDatatype), - model.SchemaType(tc.currentDataType), + tc.existingDatatype, + tc.currentDataType, tc.value, ) require.Equal(t, newColumnVal, tc.newColumnVal) diff --git a/warehouse/upload.go b/warehouse/upload.go index 69cc9a5bfa..1b90352a43 100644 --- a/warehouse/upload.go +++ b/warehouse/upload.go @@ -107,20 +107,17 @@ type UploadJob struct { pendingTableUploadsError error config struct { - refreshPartitionBatchSize int - retryTimeWindow time.Duration - minRetryAttempts int - disableAlter bool - minUploadBackoff time.Duration - maxUploadBackoff time.Duration - alwaysRegenerateAllLoadFiles bool - reportingEnabled bool - generateTableLoadCountMetrics bool - disableGenerateTableLoadCountMetricsWorkspaceIDs []string - maxParallelLoadsWorkspaceIDs map[string]interface{} - columnsBatchSize int - longRunningUploadStatThresholdInMin time.Duration - tableCountQueryTimeout time.Duration + 
refreshPartitionBatchSize int + retryTimeWindow time.Duration + minRetryAttempts int + disableAlter bool + minUploadBackoff time.Duration + maxUploadBackoff time.Duration + alwaysRegenerateAllLoadFiles bool + reportingEnabled bool + maxParallelLoadsWorkspaceIDs map[string]interface{} + columnsBatchSize int + longRunningUploadStatThresholdInMin time.Duration } errorHandler ErrorHandler @@ -219,11 +216,8 @@ func (f *UploadJobFactory) NewUploadJob(ctx context.Context, dto *model.UploadJo uj.config.disableAlter = f.conf.GetBool("Warehouse.disableAlter", false) uj.config.alwaysRegenerateAllLoadFiles = f.conf.GetBool("Warehouse.alwaysRegenerateAllLoadFiles", true) uj.config.reportingEnabled = f.conf.GetBool("Reporting.enabled", types.DefaultReportingEnabled) - uj.config.generateTableLoadCountMetrics = f.conf.GetBool("Warehouse.generateTableLoadCountMetrics", true) - uj.config.disableGenerateTableLoadCountMetricsWorkspaceIDs = f.conf.GetStringSlice("Warehouse.disableGenerateTableLoadCountMetricsWorkspaceIDs", nil) uj.config.columnsBatchSize = f.conf.GetInt(fmt.Sprintf("Warehouse.%s.columnsBatchSize", whutils.WHDestNameMap[uj.upload.DestinationType]), 100) uj.config.maxParallelLoadsWorkspaceIDs = f.conf.GetStringMap(fmt.Sprintf("Warehouse.%s.maxParallelLoadsWorkspaceIDs", whutils.WHDestNameMap[uj.upload.DestinationType]), nil) - uj.config.tableCountQueryTimeout = f.conf.GetDurationVar(30, time.Second, "Warehouse.tableCountQueryTimeout", "Warehouse.tableCountQueryTimeoutInS") uj.config.longRunningUploadStatThresholdInMin = f.conf.GetDurationVar(120, time.Minute, "Warehouse.longRunningUploadStatThreshold", "Warehouse.longRunningUploadStatThresholdInMin") uj.config.minUploadBackoff = f.conf.GetDurationVar(60, time.Second, "Warehouse.minUploadBackoff", "Warehouse.minUploadBackoffInS") uj.config.maxUploadBackoff = f.conf.GetDurationVar(1800, time.Second, "Warehouse.maxUploadBackoff", "Warehouse.maxUploadBackoffInS") @@ -1014,30 +1008,6 @@ func (job *UploadJob) updateSchema(tName string) (alteredSchema bool, err error) return } -func (job *UploadJob) getTotalCount(tName string) (int64, error) { - var ( - total int64 - countErr error - ) - - operation := func() error { - ctx, cancel := context.WithTimeout(job.ctx, job.config.tableCountQueryTimeout) - defer cancel() - - total, countErr = job.whManager.GetTotalCountInTable(ctx, tName) - return countErr - } - - expBackoff := backoff.NewExponentialBackOff() - expBackoff.InitialInterval = 5 * time.Second - expBackoff.RandomizationFactor = 0 - expBackoff.Reset() - - backoffWithMaxRetry := backoff.WithMaxRetries(expBackoff, 5) - err := backoff.Retry(operation, backoffWithMaxRetry) - return total, err -} - func (job *UploadJob) loadTable(tName string) (bool, error) { alteredSchema, err := job.updateSchema(tName) if err != nil { @@ -1068,28 +1038,7 @@ func (job *UploadJob) loadTable(tName string) (bool, error) { LastExecTime: &lastExecTime, }) - generateTableLoadCountVerificationsMetrics := job.config.generateTableLoadCountMetrics - if slices.Contains(job.config.disableGenerateTableLoadCountMetricsWorkspaceIDs, job.upload.WorkspaceID) { - generateTableLoadCountVerificationsMetrics = false - } - - var totalBeforeLoad, totalAfterLoad int64 - if generateTableLoadCountVerificationsMetrics { - var errTotalCount error - totalBeforeLoad, errTotalCount = job.getTotalCount(tName) - if errTotalCount != nil { - job.logger.Warnw("total count in table before loading", - logfield.SourceID, job.upload.SourceID, - logfield.DestinationID, job.upload.DestinationID, - 
logfield.DestinationType, job.upload.DestinationType, - logfield.WorkspaceID, job.upload.WorkspaceID, - logfield.Error, errTotalCount, - logfield.TableName, tName, - ) - } - } - - err = job.whManager.LoadTable(job.ctx, tName) + loadTableStat, err := job.whManager.LoadTable(job.ctx, tName) if err != nil { status := model.TableUploadExportingFailed errorsString := misc.QuoteLiteral(err.Error()) @@ -1099,34 +1048,25 @@ func (job *UploadJob) loadTable(tName string) (bool, error) { }) return alteredSchema, fmt.Errorf("load table: %w", err) } + if loadTableStat.RowsUpdated > 0 { + job.statsFactory.NewTaggedStat("dedup_rows", stats.CountType, stats.Tags{ + "sourceID": job.warehouse.Source.ID, + "sourceType": job.warehouse.Source.SourceDefinition.Name, + "sourceCategory": job.warehouse.Source.SourceDefinition.Category, + "destID": job.warehouse.Destination.ID, + "destType": job.warehouse.Destination.DestinationDefinition.Name, + "workspaceId": job.warehouse.WorkspaceID, + "tableName": tName, + }).Count(int(loadTableStat.RowsUpdated)) + } - func() { - if !generateTableLoadCountVerificationsMetrics { - return - } - var errTotalCount error - totalAfterLoad, errTotalCount = job.getTotalCount(tName) - if errTotalCount != nil { - job.logger.Warnw("total count in table after loading", - logfield.SourceID, job.upload.SourceID, - logfield.DestinationID, job.upload.DestinationID, - logfield.DestinationType, job.upload.DestinationType, - logfield.WorkspaceID, job.upload.WorkspaceID, - logfield.Error, errTotalCount, - logfield.TableName, tName, - ) - return - } - tableUpload, errEventCount := job.tableUploadsRepo.GetByUploadIDAndTableName(job.ctx, job.upload.ID, tName) - if errEventCount != nil { - return - } + tableUpload, errEventCount := job.tableUploadsRepo.GetByUploadIDAndTableName(job.ctx, job.upload.ID, tName) + if errEventCount != nil { + return alteredSchema, fmt.Errorf("get table upload: %w", errEventCount) + } - // TODO : Perform the comparison here in the codebase - job.guageStat(`pre_load_table_rows`, whutils.Tag{Name: "tableName", Value: strings.ToLower(tName)}).Gauge(int(totalBeforeLoad)) - job.guageStat(`post_load_table_rows_estimate`, whutils.Tag{Name: "tableName", Value: strings.ToLower(tName)}).Gauge(int(totalBeforeLoad + tableUpload.TotalEvents)) - job.guageStat(`post_load_table_rows`, whutils.Tag{Name: "tableName", Value: strings.ToLower(tName)}).Gauge(int(totalAfterLoad)) - }() + job.guageStat(`post_load_table_rows_estimate`, whutils.Tag{Name: "tableName", Value: strings.ToLower(tName)}).Gauge(int(tableUpload.TotalEvents)) + job.guageStat(`post_load_table_rows`, whutils.Tag{Name: "tableName", Value: strings.ToLower(tName)}).Gauge(int(loadTableStat.RowsInserted)) status = model.TableUploadExported _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tName, repo.TableUploadSetOptions{
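The DedupTwiceTestRecords and MismatchSchemaTestRecords fixtures above encode the rows an integration test expects to read back from the destination after loading. A minimal sketch of how such a fixture is typically consumed follows; the package name, column list, and query are assumptions for illustration only (the real test helpers in the repository define their own schema and normalize values such as timestamps before comparing):

package integrationstest // hypothetical package name, for illustration only

import (
	"database/sql"
	"fmt"
	"testing"

	"github.com/stretchr/testify/require"
)

// Hypothetical column list; the actual integration tests define their own table schema.
const testColumns = "id, received_at, test_bool, test_datetime, test_float, test_int, test_string"

// assertLoadedRecords shows how a fixture such as DedupTwiceTestRecords() would
// typically be compared against the rows that ended up in the destination table
// after the same load file was processed twice.
func assertLoadedRecords(t *testing.T, db *sql.DB, namespace, table string, expected [][]string) {
	t.Helper()

	rows, err := db.Query(fmt.Sprintf("SELECT %s FROM %s.%s", testColumns, namespace, table))
	require.NoError(t, err)
	defer func() { _ = rows.Close() }()

	var got [][]string
	for rows.Next() {
		cols := make([]sql.NullString, 7)
		dest := make([]any, len(cols))
		for i := range cols {
			dest[i] = &cols[i]
		}
		require.NoError(t, rows.Scan(dest...))

		record := make([]string, len(cols))
		for i, c := range cols {
			record[i] = c.String // NULL becomes "", matching the fixture rows
		}
		got = append(got, record)
	}
	require.NoError(t, rows.Err())
	require.ElementsMatch(t, expected, got)
}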
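Switching SchemaType from a defined type to an alias (type SchemaType = string) is what lets the schema.go and slave_worker.go hunks above drop their model.SchemaType(...) and string(...) conversions: values held in model.TableSchema, a map[string]string, now compare against the data-type constants directly. A self-contained illustration, using only the declarations visible in this diff:

package main

import "fmt"

// Mirrors the declarations in warehouse/internal/model/schema.go after this change.
type (
	SchemaType  = string // alias: SchemaType and string are interchangeable
	TableSchema map[string]string
)

const TextDataType SchemaType = "text"

func main() {
	columns := TableSchema{"body": "text"}

	// With the alias, no conversions are needed: a map value of type string
	// compares against, and is assignable from, the typed constant. With the
	// previous defined type, this comparison would not compile without a cast.
	if columns["body"] == TextDataType {
		columns["body"] = TextDataType
		fmt.Println("text column kept as-is")
	}
}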
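With getTotalCount and its backoff-wrapped COUNT(*) queries removed, upload.go now takes row counts from the *types.LoadTableStats value returned by LoadTable. The sketch below shows the shape of that contract as it appears at the call site; the manager interface and fakeManager are assumptions for illustration, and only the LoadTable call pattern and the LoadTableStats fields are taken from this diff:

package main

import (
	"context"
	"fmt"

	"github.com/rudderlabs/rudder-server/warehouse/integrations/types"
)

// manager is a stand-in for the warehouse manager used by UploadJob; the real
// interface lives elsewhere in the repository and may differ.
type manager interface {
	LoadTable(ctx context.Context, tableName string) (*types.LoadTableStats, error)
}

type fakeManager struct{}

func (fakeManager) LoadTable(context.Context, string) (*types.LoadTableStats, error) {
	return &types.LoadTableStats{RowsInserted: 120, RowsUpdated: 5}, nil
}

func main() {
	var m manager = fakeManager{}

	stats, err := m.LoadTable(context.Background(), "tracks")
	if err != nil {
		panic(err)
	}

	// upload.go gauges post_load_table_rows from RowsInserted and, when
	// RowsUpdated > 0, counts the replaced rows on the dedup_rows metric.
	fmt.Printf("inserted=%d updated=%d\n", stats.RowsInserted, stats.RowsUpdated)
}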
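Nothing in this excerpt shows how an integration computes RowsInserted and RowsUpdated, so the following is only one plausible accounting scheme under a delete-then-insert load into a staging table, not necessarily what these destinations do; the table names, "id" column, and query shapes are placeholders:

package example // hypothetical

import (
	"context"
	"database/sql"
	"fmt"

	"github.com/rudderlabs/rudder-server/warehouse/integrations/types"
)

// loadTableStats derives LoadTableStats from RowsAffected: rows deleted before
// re-insertion are counted as "updated" (they feed dedup_rows), the remainder
// as fresh inserts. This is a sketch, not the SQL this PR's integrations run.
func loadTableStats(ctx context.Context, tx *sql.Tx, table, stagingTable string) (*types.LoadTableStats, error) {
	delRes, err := tx.ExecContext(ctx, fmt.Sprintf(
		`DELETE FROM %[1]q WHERE id IN (SELECT id FROM %[2]q)`, table, stagingTable))
	if err != nil {
		return nil, fmt.Errorf("deleting duplicate rows: %w", err)
	}
	insRes, err := tx.ExecContext(ctx, fmt.Sprintf(
		`INSERT INTO %[1]q SELECT * FROM %[2]q`, table, stagingTable))
	if err != nil {
		return nil, fmt.Errorf("inserting from staging table: %w", err)
	}

	updated, err := delRes.RowsAffected()
	if err != nil {
		return nil, fmt.Errorf("rows affected (delete): %w", err)
	}
	inserted, err := insRes.RowsAffected()
	if err != nil {
		return nil, fmt.Errorf("rows affected (insert): %w", err)
	}

	return &types.LoadTableStats{
		RowsInserted: inserted - updated, // rows that did not replace an existing row
		RowsUpdated:  updated,            // rows that replaced existing ones
	}, nil
}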