diff --git a/clients/databricks/dialect/dialect.go b/clients/databricks/dialect/dialect.go index 57021cbbb..2282b0074 100644 --- a/clients/databricks/dialect/dialect.go +++ b/clients/databricks/dialect/dialect.go @@ -67,6 +67,8 @@ func (d DatabricksDialect) KindForDataType(_type string, _ string) (typing.KindD return typing.Boolean, nil case "VARIANT": return typing.Struct, nil + case "TIMESTAMP": + return typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), nil } return typing.KindDetails{}, fmt.Errorf("unsupported data type: %q", _type) @@ -81,11 +83,7 @@ func (DatabricksDialect) IsTableDoesNotExistErr(err error) bool { } func (d DatabricksDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, temporary bool, colSQLParts []string) string { - temp := "" - if temporary { - temp = "TEMPORARY " - } - return fmt.Sprintf("CREATE %sTABLE %s (%s)", temp, tableID.FullyQualifiedName(), strings.Join(colSQLParts, ", ")) + return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ", ")) } func (d DatabricksDialect) BuildAlterColumnQuery(tableID sql.TableIdentifier, columnOp constants.ColumnOperation, colSQLPart string) string { diff --git a/clients/databricks/store.go b/clients/databricks/store.go index b453c8501..9a5a38967 100644 --- a/clients/databricks/store.go +++ b/clients/databricks/store.go @@ -1,9 +1,17 @@ package databricks import ( + "context" + "encoding/csv" "fmt" + "log/slog" + "os" + "path/filepath" - _ "github.com/databricks/databricks-sql-go" + "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/destination/ddl" + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/values" "github.com/artie-labs/transfer/clients/databricks/dialect" "github.com/artie-labs/transfer/clients/shared" @@ -13,6 +21,8 @@ import ( "github.com/artie-labs/transfer/lib/kafkalib" "github.com/artie-labs/transfer/lib/optimization" "github.com/artie-labs/transfer/lib/sql" + _ "github.com/databricks/databricks-sql-go" + driverctx "github.com/databricks/databricks-sql-go/driverctx" ) type Store struct { @@ -22,7 +32,12 @@ type Store struct { } func describeTableQuery(tableID TableIdentifier) (string, []any) { - return fmt.Sprintf("DESCRIBE TABLE %s.%s.%s", tableID.Database(), tableID.Schema(), tableID.Table()), nil + _dialect := dialect.DatabricksDialect{} + return fmt.Sprintf("DESCRIBE TABLE %s.%s.%s", + _dialect.QuoteIdentifier(tableID.Database()), + _dialect.QuoteIdentifier(tableID.Schema()), + _dialect.QuoteIdentifier(tableID.Table()), + ), nil } func (s Store) Merge(tableData *optimization.TableData) error { @@ -54,15 +69,111 @@ func (s Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTabl ConfigMap: s.configMap, Query: query, Args: args, - ColumnNameForName: "column_name", + ColumnNameForName: "col_name", ColumnNameForDataType: "data_type", - ColumnNameForComment: "description", + ColumnNameForComment: "comment", DropDeletedColumns: tableData.TopicConfig().DropDeletedColumns, }.GetTableConfig() } -func (s Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, parentTableID sql.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error { - panic("not implemented") +func (s Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error { + if createTempTable { + tempAlterTableArgs := ddl.AlterTableArgs{ + Dialect: s.Dialect(), + Tc: tableConfig, + TableID: tempTableID, + CreateTable: true, + TemporaryTable: true, + ColumnOp: constants.Add, + Mode: tableData.Mode(), + } + + if err := tempAlterTableArgs.AlterTable(s, tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil { + return fmt.Errorf("failed to create temp table: %w", err) + } + } + + // Write data into a temporary file + fp, err := s.writeTemporaryTableFile(tableData, tempTableID) + if err != nil { + return fmt.Errorf("failed to load temporary table: %w", err) + } + + defer func() { + // In the case where PUT or COPY fails, we'll at least delete the temporary file. + if deleteErr := os.RemoveAll(fp); deleteErr != nil { + slog.Warn("Failed to delete temp file", slog.Any("err", deleteErr), slog.String("filePath", fp)) + } + }() + + castedTempTableID, isOk := tempTableID.(TableIdentifier) + if !isOk { + return fmt.Errorf("failed to cast tempTableID to TableIdentifier") + } + + dbfsFilePath := fmt.Sprintf("dbfs:/Volumes/%s/%s/vol_test/%s.csv", castedTempTableID.Database(), castedTempTableID.Schema(), tempTableID.Table()) + + ctx := driverctx.NewContextWithStagingInfo(context.Background(), []string{"/var"}) + + // Use the PUT INTO command to upload the file to Databricks + putCommand := fmt.Sprintf("PUT '%s' INTO '%s' OVERWRITE", fp, dbfsFilePath) + if _, err = s.ExecContext(ctx, putCommand); err != nil { + return fmt.Errorf("failed to run PUT INTO for temporary table: %w", err) + } + + // Use the COPY INTO command to load the data into the temporary table + copyCommand := fmt.Sprintf("COPY INTO %s BY POSITION FROM '%s' FILEFORMAT = CSV FORMAT_OPTIONS ('delimiter' = '\t', 'header' = 'false')", tempTableID.FullyQualifiedName(), dbfsFilePath) + if _, err = s.ExecContext(ctx, copyCommand); err != nil { + return fmt.Errorf("failed to run COPY INTO for temporary table: %w", err) + } + + return nil +} + +func castColValStaging(colVal any, colKind typing.KindDetails) (string, error) { + if colVal == nil { + // \\N needs to match NULL_IF(...) from ddl.go + return `\\N`, nil + } + + value, err := values.ToString(colVal, colKind) + if err != nil { + return "", err + } + + return value, nil +} + +func (s Store) writeTemporaryTableFile(tableData *optimization.TableData, newTableID sql.TableIdentifier) (string, error) { + fp := filepath.Join(os.TempDir(), fmt.Sprintf("%s.csv", newTableID.FullyQualifiedName())) + file, err := os.Create(fp) + if err != nil { + return "", err + } + + defer file.Close() + writer := csv.NewWriter(file) + writer.Comma = '\t' + + columns := tableData.ReadOnlyInMemoryCols().ValidColumns() + for _, value := range tableData.Rows() { + var row []string + for _, col := range columns { + castedValue, castErr := castColValStaging(value[col.Name()], col.KindDetails) + if castErr != nil { + return "", castErr + } + + row = append(row, castedValue) + } + + if err = writer.Write(row); err != nil { + return "", fmt.Errorf("failed to write to csv: %w", err) + } + } + + writer.Flush() + return fp, writer.Error() } func LoadStore(cfg config.Config) (Store, error) { diff --git a/clients/databricks/tableid.go b/clients/databricks/tableid.go index fcb6375f6..df1313063 100644 --- a/clients/databricks/tableid.go +++ b/clients/databricks/tableid.go @@ -44,5 +44,5 @@ func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { } func (ti TableIdentifier) FullyQualifiedName() string { - return fmt.Sprintf("%s.%s.%s", ti.database, ti.schema, ti.EscapedTable()) + return fmt.Sprintf("%s.%s.%s", _dialect.QuoteIdentifier(ti.database), _dialect.QuoteIdentifier(ti.schema), ti.EscapedTable()) } diff --git a/lib/config/destination_types.go b/lib/config/destination_types.go index 44d16a3a0..de258ae13 100644 --- a/lib/config/destination_types.go +++ b/lib/config/destination_types.go @@ -59,5 +59,4 @@ type Databricks struct { Database string `yaml:"database"` Protocol string `yaml:"protocol"` CatalogName string `yaml:"catalogName"` - SchemaName string `yaml:"schemaName"` } diff --git a/lib/config/destinations.go b/lib/config/destinations.go index 81bb6bd18..45bf51de3 100644 --- a/lib/config/destinations.go +++ b/lib/config/destinations.go @@ -77,7 +77,6 @@ func (s Snowflake) ToConfig() (*gosnowflake.Config, error) { func (d Databricks) DSN() string { query := url.Values{} query.Add("catalog", d.CatalogName) - query.Add("schema", d.SchemaName) u := &url.URL{ Path: "/sql/1.0/warehouses/cab738c29ff77d72", User: url.UserPassword("token", d.PersonalAccessToken), diff --git a/lib/db/db.go b/lib/db/db.go index 49cbd977f..529677afc 100644 --- a/lib/db/db.go +++ b/lib/db/db.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "fmt" "log/slog" @@ -15,8 +16,9 @@ const ( ) type Store interface { - Exec(query string, args ...any) (sql.Result, error) Query(query string, args ...any) (*sql.Rows, error) + Exec(query string, args ...any) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) Begin() (*sql.Tx, error) IsRetryableError(err error) bool } @@ -25,6 +27,28 @@ type storeWrapper struct { *sql.DB } +func (s *storeWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + var result sql.Result + var err error + for attempts := 0; attempts < maxAttempts; attempts++ { + if attempts > 0 { + sleepDuration := jitter.Jitter(sleepBaseMs, jitter.DefaultMaxMs, attempts-1) + slog.Warn("Failed to execute the query, retrying...", + slog.Any("err", err), + slog.Duration("sleep", sleepDuration), + slog.Int("attempts", attempts), + ) + time.Sleep(sleepDuration) + } + + result, err = s.DB.ExecContext(ctx, query, args...) + if err == nil || !s.IsRetryableError(err) { + break + } + } + return result, err +} + func (s *storeWrapper) Exec(query string, args ...any) (sql.Result, error) { var result sql.Result var err error