Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call create table in more places. #1034

Merged
merged 7 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions clients/databricks/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/artie-labs/transfer/clients/databricks/dialect"
"github.com/artie-labs/transfer/clients/shared"
"github.com/artie-labs/transfer/lib/config"
"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/db"
"github.com/artie-labs/transfer/lib/destination/ddl"
"github.com/artie-labs/transfer/lib/destination/types"
Expand Down Expand Up @@ -81,19 +80,14 @@ func (s Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTabl
}.GetTableConfig()
}

func (s Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
func (s Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, _ *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
Mode: tableData.Mode(),
query, err := ddl.BuildCreateTableSQL(s.Dialect(), tempTableID, true, tableData.Mode(), tableData.ReadOnlyInMemoryCols().GetColumns())
if err != nil {
return fmt.Errorf("failed to build create table sql: %w", err)
}

if err := tempAlterTableArgs.AlterTable(s, tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
if _, err = s.ExecContext(ctx, query); err != nil {
return fmt.Errorf("failed to create temp table: %w", err)
}
}
Expand Down
16 changes: 5 additions & 11 deletions clients/mssql/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,21 @@ import (

mssql "github.com/microsoft/go-mssqldb"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/ddl"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing/columns"
)

func (s *Store) PrepareTemporaryTable(_ context.Context, tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, _ *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
Mode: tableData.Mode(),
query, err := ddl.BuildCreateTableSQL(s.Dialect(), tempTableID, true, tableData.Mode(), tableData.ReadOnlyInMemoryCols().GetColumns())
if err != nil {
return fmt.Errorf("failed to build create table sql: %w", err)
}

if err := tempAlterTableArgs.AlterTable(s, tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
if _, err = s.ExecContext(ctx, query); err != nil {
return fmt.Errorf("failed to create temp table: %w", err)
}
}
Expand Down
14 changes: 4 additions & 10 deletions clients/redshift/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"os"
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/ddl"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/optimization"
Expand Down Expand Up @@ -41,17 +40,12 @@ func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimizati
}

if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
Mode: tableData.Mode(),
query, err := ddl.BuildCreateTableSQL(s.Dialect(), tempTableID, true, tableData.Mode(), tableData.ReadOnlyInMemoryCols().GetColumns())
if err != nil {
return fmt.Errorf("failed to build create table sql: %w", err)
}

if err = tempAlterTableArgs.AlterTable(s, tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
if _, err = s.ExecContext(ctx, query); err != nil {
return fmt.Errorf("failed to create temp table: %w", err)
}
}
Expand Down
27 changes: 14 additions & 13 deletions clients/snowflake/snowflake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() {

s.stageStore.configMap.AddTableToConfig(s.identifierFor(tableData), types.NewDwhTableConfig(cols.GetColumns(), true))
assert.NoError(s.T(), s.stageStore.Merge(context.Background(), tableData))
assert.Equal(s.T(), 5, s.fakeStageStore.ExecCallCount())
assert.Equal(s.T(), 4, s.fakeStageStore.ExecCallCount())
assert.Equal(s.T(), 1, s.fakeStageStore.ExecContextCallCount())
}

func (s *SnowflakeTestSuite) TestExecuteMerge() {
Expand Down Expand Up @@ -165,36 +166,35 @@ func (s *SnowflakeTestSuite) TestExecuteMerge() {
tableData.InsertRow(pk, row, false)
}

var idx int

tableID := s.identifierFor(tableData)
fqName := tableID.FullyQualifiedName()
s.stageStore.configMap.AddTableToConfig(tableID, types.NewDwhTableConfig(cols.GetColumns(), true))
err := s.stageStore.Merge(context.Background(), tableData)
assert.Nil(s.T(), err)
s.fakeStageStore.ExecReturns(nil, nil)
// CREATE TABLE IF NOT EXISTS customer.public.orders___artie_Mwv9YADmRy (id int,name string,__artie_delete boolean,created_at timestamp_tz) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE) COMMENT='expires:2023-06-27 11:54:03 UTC'
createQuery, _ := s.fakeStageStore.ExecArgsForCall(idx)
_, createQuery, _ := s.fakeStageStore.ExecContextArgsForCall(0)
assert.Contains(s.T(), createQuery, `customer.public."ORDERS___ARTIE_`, fmt.Sprintf("query: %v, destKind: %v", createQuery, constants.Snowflake))

// PUT file:///tmp/customer.public.orders___artie_Mwv9YADmRy.csv @customer.public.%orders___artie_Mwv9YADmRy AUTO_COMPRESS=TRUE
putQuery, _ := s.fakeStageStore.ExecArgsForCall(idx + 1)
putQuery, _ := s.fakeStageStore.ExecArgsForCall(0)
assert.Contains(s.T(), putQuery, "PUT file://")

// COPY INTO customer.public.orders___artie_Mwv9YADmRy (id,name,__artie_delete,created_at) FROM (SELECT $1,$2,$3,$4 FROM @customer.public.%orders___artie_Mwv9YADmRy
copyQuery, _ := s.fakeStageStore.ExecArgsForCall(idx + 2)
copyQuery, _ := s.fakeStageStore.ExecArgsForCall(1)
assert.Contains(s.T(), copyQuery, `COPY INTO customer.public."ORDERS___ARTIE_`, fmt.Sprintf("query: %v, destKind: %v", copyQuery, constants.Snowflake))
assert.Contains(s.T(), copyQuery, fmt.Sprintf("FROM %s", "@customer.public.\"%ORDERS___ARTIE_"), fmt.Sprintf("query: %v, destKind: %v", copyQuery, constants.Snowflake))

mergeQuery, _ := s.fakeStageStore.ExecArgsForCall(idx + 3)
mergeQuery, _ := s.fakeStageStore.ExecArgsForCall(2)
assert.Contains(s.T(), mergeQuery, fmt.Sprintf("MERGE INTO %s", fqName), fmt.Sprintf("query: %v, destKind: %v", mergeQuery, constants.Snowflake))

// Drop a table now.
dropQuery, _ := s.fakeStageStore.ExecArgsForCall(idx + 4)
dropQuery, _ := s.fakeStageStore.ExecArgsForCall(3)
assert.Contains(s.T(), dropQuery, `DROP TABLE IF EXISTS customer.public."ORDERS___ARTIE_`,
fmt.Sprintf("query: %v, destKind: %v", dropQuery, constants.Snowflake))

assert.Equal(s.T(), 5, s.fakeStageStore.ExecCallCount(), "called merge")
assert.Equal(s.T(), 4, s.fakeStageStore.ExecCallCount())
assert.Equal(s.T(), 1, s.fakeStageStore.ExecContextCallCount())
}

// TestExecuteMergeDeletionFlagRemoval is going to run execute merge twice.
Expand Down Expand Up @@ -256,7 +256,8 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() {

assert.NoError(s.T(), s.stageStore.Merge(context.Background(), tableData))
s.fakeStageStore.ExecReturns(nil, nil)
assert.Equal(s.T(), s.fakeStageStore.ExecCallCount(), 5, "called merge")
assert.Equal(s.T(), 4, s.fakeStageStore.ExecCallCount())
assert.Equal(s.T(), 1, s.fakeStageStore.ExecContextCallCount())

// Check the temp deletion table now.
assert.Equal(s.T(), len(s.stageStore.configMap.TableConfigCache(s.identifierFor(tableData)).ReadOnlyColumnsToDelete()), 1,
Expand All @@ -279,11 +280,11 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() {

assert.NoError(s.T(), s.stageStore.Merge(context.Background(), tableData))
s.fakeStageStore.ExecReturns(nil, nil)
assert.Equal(s.T(), s.fakeStageStore.ExecCallCount(), 10, "called merge again")
assert.Equal(s.T(), 8, s.fakeStageStore.ExecCallCount())
assert.Equal(s.T(), 2, s.fakeStageStore.ExecContextCallCount())

// Caught up now, so columns should be 0.
assert.Equal(s.T(), len(s.stageStore.configMap.TableConfigCache(s.identifierFor(tableData)).ReadOnlyColumnsToDelete()), 0,
s.stageStore.configMap.TableConfigCache(s.identifierFor(tableData)).ReadOnlyColumnsToDelete())
assert.Len(s.T(), s.stageStore.configMap.TableConfigCache(s.identifierFor(tableData)).ReadOnlyColumnsToDelete(), 0)
}

func (s *SnowflakeTestSuite) TestExecuteMergeExitEarly() {
Expand Down
15 changes: 5 additions & 10 deletions clients/snowflake/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,14 @@ func castColValStaging(colVal any, colKind typing.KindDetails) (string, error) {
return replaceExceededValues(value, colKind), nil
}

func (s *Store) PrepareTemporaryTable(_ context.Context, tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error {
func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, _ *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
Mode: tableData.Mode(),
query, err := ddl.BuildCreateTableSQL(s.Dialect(), tempTableID, true, tableData.Mode(), tableData.ReadOnlyInMemoryCols().GetColumns())
if err != nil {
return fmt.Errorf("failed to build create table sql: %w", err)
}

if err := tempAlterTableArgs.AlterTable(s, tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
if _, err = s.ExecContext(ctx, query); err != nil {
return fmt.Errorf("failed to create temp table: %w", err)
}
}
Expand Down
16 changes: 8 additions & 8 deletions clients/snowflake/staging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
"os"
"strings"

"github.com/artie-labs/transfer/clients/snowflake/dialect"

"github.com/stretchr/testify/assert"

"github.com/artie-labs/transfer/clients/shared"
"github.com/artie-labs/transfer/clients/snowflake/dialect"
"github.com/artie-labs/transfer/lib/config"
"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/types"
Expand Down Expand Up @@ -160,31 +159,32 @@ func (s *SnowflakeTestSuite) TestPrepareTempTable() {

{
assert.NoError(s.T(), s.stageStore.PrepareTemporaryTable(context.Background(), tableData, sflkTc, tempTableID, tempTableID, types.AdditionalSettings{}, true))
assert.Equal(s.T(), 3, s.fakeStageStore.ExecCallCount())
assert.Equal(s.T(), 2, s.fakeStageStore.ExecCallCount())
assert.Equal(s.T(), 1, s.fakeStageStore.ExecContextCallCount())

// First call is to create the temp table
createQuery, _ := s.fakeStageStore.ExecArgsForCall(0)
_, createQuery, _ := s.fakeStageStore.ExecContextArgsForCall(0)

prefixQuery := fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s ("USER_ID" string,"FIRST_NAME" string,"LAST_NAME" string,"DUSTY" string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`, tempTableName)
containsPrefix := strings.HasPrefix(createQuery, prefixQuery)
assert.True(s.T(), containsPrefix, fmt.Sprintf("createQuery:%v, prefixQuery:%s", createQuery, prefixQuery))
resourceName := addPrefixToTableName(tempTableID, "%")
// Second call is a PUT
putQuery, _ := s.fakeStageStore.ExecArgsForCall(1)
putQuery, _ := s.fakeStageStore.ExecArgsForCall(0)
assert.Contains(s.T(), putQuery, "PUT file://", putQuery)
assert.Contains(s.T(), putQuery, fmt.Sprintf("@%s AUTO_COMPRESS=TRUE", resourceName))
// Third call is a COPY INTO
copyQuery, _ := s.fakeStageStore.ExecArgsForCall(2)
copyQuery, _ := s.fakeStageStore.ExecArgsForCall(1)
assert.Equal(s.T(), fmt.Sprintf(`COPY INTO %s ("USER_ID","FIRST_NAME","LAST_NAME","DUSTY") FROM (SELECT $1,$2,$3,$4 FROM @%s)`,
tempTableName, resourceName), copyQuery)
}
{
// Don't create the temporary table.
assert.NoError(s.T(), s.stageStore.PrepareTemporaryTable(context.Background(), tableData, sflkTc, tempTableID, tempTableID, types.AdditionalSettings{}, false))
assert.Equal(s.T(), 5, s.fakeStageStore.ExecCallCount())
assert.Equal(s.T(), 4, s.fakeStageStore.ExecCallCount())
assert.Equal(s.T(), 1, s.fakeStageStore.ExecContextCallCount())
}

}

func (s *SnowflakeTestSuite) TestLoadTemporaryTable() {
Expand Down