Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support different tag association queries for COLUMN object types #1380

28 changes: 25 additions & 3 deletions pkg/snowflake/tag_association.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"strings"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/validation"
"github.com/jmoiron/sqlx"
Expand Down Expand Up @@ -49,14 +50,25 @@ func (tb *TagAssociationBuilder) GetTagDatabase() string {

// GetTagName returns the value of the tag name of TagAssociationBuilder.
func (tb *TagAssociationBuilder) GetTagName() string {
return tb.schemaName
return tb.tagName
}

// GetTagSchema returns the value of the tag schema of TagAssociationBuilder.
func (tb *TagAssociationBuilder) GetTagSchema() string {
return tb.schemaName
}

func (tb *TagAssociationBuilder) GetTableAndColumnName() (string, string) {
if strings.ToUpper(tb.objectType) != "COLUMN" {
return tb.objectIdentifier, ""
} else {
splObjIdentifier := strings.Split(tb.objectIdentifier, ".")
tableName := strings.ReplaceAll(splObjIdentifier[2], "\"", "")
columnName := strings.ReplaceAll(splObjIdentifier[3], "\"", "")
return fmt.Sprintf(`"%s"."%s"."%s"`, tb.databaseName, tb.schemaName, tableName), columnName
}
}

// TagAssociation returns a pointer to a Builder that abstracts the DDL operations for a tag sssociation.
//
// Supported DDL operations are:
Expand All @@ -76,12 +88,22 @@ func NewTagAssociationBuilder(tagID string) *TagAssociationBuilder {

// Create returns the SQL query that will set the tag on an object.
func (tb *TagAssociationBuilder) Create() string {
return fmt.Sprintf(`ALTER %v %v SET TAG "%v"."%v"."%v" = '%v'`, tb.objectType, tb.objectIdentifier, tb.databaseName, tb.schemaName, tb.tagName, EscapeString(tb.tagValue))
if strings.ToUpper(tb.objectType) == "COLUMN" {
tableName, columnName := tb.GetTableAndColumnName()
return fmt.Sprintf(`ALTER TABLE %v ALTER COLUMN %v SET TAG "%v"."%v"."%v" = '%v'`, tableName, columnName, tb.databaseName, tb.schemaName, tb.tagName, EscapeString(tb.tagValue))
rmorris1218 marked this conversation as resolved.
Show resolved Hide resolved
} else {
return fmt.Sprintf(`ALTER %v %v SET TAG "%v"."%v"."%v" = '%v'`, tb.objectType, tb.objectIdentifier, tb.databaseName, tb.schemaName, tb.tagName, EscapeString(tb.tagValue))
}
}

// Drop returns the SQL query that will remove a tag from an object.
func (tb *TagAssociationBuilder) Drop() string {
return fmt.Sprintf(`ALTER %v %v UNSET TAG "%v"."%v"."%v"`, tb.objectType, tb.objectIdentifier, tb.databaseName, tb.schemaName, tb.tagName)
if strings.ToUpper(tb.objectType) == "COLUMN" {
tableName, columnName := tb.GetTableAndColumnName()
return fmt.Sprintf(`ALTER TABLE %v ALTER COLUMN %v UNSET TAG "%v"."%v"."%v"`, tableName, columnName, tb.databaseName, tb.schemaName, tb.tagName)
rmorris1218 marked this conversation as resolved.
Show resolved Hide resolved
} else {
return fmt.Sprintf(`ALTER %v %v UNSET TAG "%v"."%v"."%v"`, tb.objectType, tb.objectIdentifier, tb.databaseName, tb.schemaName, tb.tagName)
}
}

// Show returns the SQL query that will show the current tag value on an object.
Expand Down
54 changes: 54 additions & 0 deletions pkg/snowflake/tag_association_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package snowflake

import (
"testing"

"github.com/stretchr/testify/require"
)

type TagAssociationTest struct {
Builder *TagAssociationBuilder
ExpectedCreate string
ExpectedDrop string
}

func TestTagAssociation(t *testing.T) {
tests := []TagAssociationTest{
{
Builder: NewTagAssociationBuilder("test_db|test_schema|sensitive").WithObjectIdentifier(`"test_db"."test_schema"."test_table"`).WithObjectType("TABLE").WithTagValue("true"),
ExpectedCreate: `ALTER TABLE "test_db"."test_schema"."test_table" SET TAG "test_db"."test_schema"."sensitive" = 'true'`,
ExpectedDrop: `ALTER TABLE "test_db"."test_schema"."test_table" UNSET TAG "test_db"."test_schema"."sensitive"`,
},
{
Builder: NewTagAssociationBuilder("test_db|test_schema|sensitive").WithObjectIdentifier(`"test_db"."test_schema"."test_table.important"`).WithObjectType("COLUMN").WithTagValue("true"),
ExpectedCreate: `ALTER TABLE "test_db"."test_schema"."test_table" ALTER COLUMN important SET TAG "test_db"."test_schema"."sensitive" = 'true'`,
rmorris1218 marked this conversation as resolved.
Show resolved Hide resolved
ExpectedDrop: `ALTER TABLE "test_db"."test_schema"."test_table" ALTER COLUMN important UNSET TAG "test_db"."test_schema"."sensitive"`},
rmorris1218 marked this conversation as resolved.
Show resolved Hide resolved
{
Builder: NewTagAssociationBuilder("OPERATION_DB|SECURITY|PII_2").WithObjectIdentifier(`"OPERATION_DB"."SECURITY"."test_table.important"`).WithObjectType("COLUMN").WithTagValue("true"),
ExpectedCreate: `ALTER TABLE "OPERATION_DB"."SECURITY"."test_table" ALTER COLUMN important SET TAG "OPERATION_DB"."SECURITY"."PII_2" = 'true'`,
rmorris1218 marked this conversation as resolved.
Show resolved Hide resolved
ExpectedDrop: `ALTER TABLE "OPERATION_DB"."SECURITY"."test_table" ALTER COLUMN important UNSET TAG "OPERATION_DB"."SECURITY"."PII_2"`},
rmorris1218 marked this conversation as resolved.
Show resolved Hide resolved
}
for _, testCase := range tests {
r := require.New(t)
r.Equal(testCase.ExpectedCreate, testCase.Builder.Create())
r.Equal(testCase.ExpectedDrop, testCase.Builder.Drop())
}
}

type TableColumnNameTest struct {
Builder *TagAssociationBuilder
expectedTableName, expectedColumnName string
}

func TestTableColumnName(t *testing.T) {
tests := []TableColumnNameTest{
{NewTagAssociationBuilder("a|b|sensitive").WithObjectIdentifier(`"a"."b"."c"`).WithObjectType("TABLE"), `"a"."b"."c"`, ""},
{NewTagAssociationBuilder("db|schema|sensitive").WithObjectIdentifier(`"db"."schema"."table.column"`).WithObjectType("COLUMN"), `"db"."schema"."table"`, "column"},
}
for _, testCase := range tests {
r := require.New(t)
tableName, columnName := testCase.Builder.GetTableAndColumnName()
r.Equal(testCase.expectedTableName, tableName)
r.Equal(testCase.expectedColumnName, columnName)
}
}