diff --git a/pkg/acceptance/helpers/information_schema_client.go b/pkg/acceptance/helpers/information_schema_client.go new file mode 100644 index 0000000000..9ed99e4e19 --- /dev/null +++ b/pkg/acceptance/helpers/information_schema_client.go @@ -0,0 +1,44 @@ +package helpers + +import ( + "context" + "fmt" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/stretchr/testify/require" +) + +type InformationSchemaClient struct { + context *TestClientContext + ids *IdsGenerator +} + +func NewInformationSchemaClient(context *TestClientContext, idsGenerator *IdsGenerator) *InformationSchemaClient { + return &InformationSchemaClient{ + context: context, + ids: idsGenerator, + } +} + +func (c *InformationSchemaClient) client() *sdk.Client { + return c.context.client +} + +func (c *InformationSchemaClient) GetQueryTextByQueryId(t *testing.T, queryId string) string { + t.Helper() + result, err := c.client().QueryUnsafe(context.Background(), fmt.Sprintf("SELECT QUERY_TEXT FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY(RESULT_LIMIT => 20)) WHERE QUERY_ID = '%s'", queryId)) + require.NoError(t, err) + require.Len(t, result, 1) + require.NotNil(t, result[0]["QUERY_TEXT"]) + return (*result[0]["QUERY_TEXT"]).(string) +} + +func (c *InformationSchemaClient) GetQueryTagByQueryId(t *testing.T, queryId string) string { + t.Helper() + result, err := c.client().QueryUnsafe(context.Background(), fmt.Sprintf("SELECT QUERY_TAG FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY(RESULT_LIMIT => 20)) WHERE QUERY_ID = '%s'", queryId)) + require.NoError(t, err) + require.Len(t, result, 1) + require.NotNil(t, result[0]["QUERY_TAG"]) + return (*result[0]["QUERY_TAG"]).(string) +} diff --git a/pkg/acceptance/helpers/test_client.go b/pkg/acceptance/helpers/test_client.go index 53a9b6cb2d..8c2a3cccb1 100644 --- a/pkg/acceptance/helpers/test_client.go +++ b/pkg/acceptance/helpers/test_client.go @@ -36,6 +36,7 @@ type TestClient struct { FileFormat *FileFormatClient Function *FunctionClient Grant *GrantClient + InformationSchema *InformationSchemaClient MaskingPolicy *MaskingPolicyClient MaterializedView *MaterializedViewClient NetworkPolicy *NetworkPolicyClient @@ -108,6 +109,7 @@ func NewTestClient(c *sdk.Client, database string, schema string, warehouse stri FileFormat: NewFileFormatClient(context, idsGenerator), Function: NewFunctionClient(context, idsGenerator), Grant: NewGrantClient(context, idsGenerator), + InformationSchema: NewInformationSchemaClient(context, idsGenerator), MaskingPolicy: NewMaskingPolicyClient(context, idsGenerator), MaterializedView: NewMaterializedViewClient(context, idsGenerator), NetworkPolicy: NewNetworkPolicyClient(context, idsGenerator), diff --git a/pkg/acceptance/helpers/user_client.go b/pkg/acceptance/helpers/user_client.go index c64afcf723..20461ae6e5 100644 --- a/pkg/acceptance/helpers/user_client.go +++ b/pkg/acceptance/helpers/user_client.go @@ -68,6 +68,14 @@ func (c *UserClient) Alter(t *testing.T, id sdk.AccountObjectIdentifier, opts *s require.NoError(t, err) } +func (c *UserClient) AlterCurrentUser(t *testing.T, opts *sdk.AlterUserOptions) { + t.Helper() + id, err := c.context.client.ContextFunctions.CurrentUser(context.Background()) + require.NoError(t, err) + err = c.client().Alter(context.Background(), id, opts) + require.NoError(t, err) +} + func (c *UserClient) DropUserFunc(t *testing.T, id sdk.AccountObjectIdentifier) func() { t.Helper() ctx := context.Background() diff --git a/pkg/internal/tracking/context.go b/pkg/internal/tracking/context.go new file mode 100644 index 0000000000..9519bf1bb4 --- /dev/null +++ b/pkg/internal/tracking/context.go @@ -0,0 +1,73 @@ +package tracking + +import ( + "context" + "errors" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" +) + +const ( + ProviderVersion string = "v0.99.0" // TODO(SNOW-1814934): Currently hardcoded, make it computed + MetadataPrefix string = "terraform_provider_usage_tracking" +) + +type key struct{} + +var metadataContextKey key + +type Operation string + +const ( + CreateOperation Operation = "create" + ReadOperation Operation = "read" + UpdateOperation Operation = "update" + DeleteOperation Operation = "delete" + ImportOperation Operation = "import" + CustomDiffOperation Operation = "custom_diff" +) + +type Metadata struct { + Version string `json:"version,omitempty"` + Resource string `json:"resource,omitempty"` + Operation Operation `json:"operation,omitempty"` +} + +func (m Metadata) validate() error { + errs := make([]error, 0) + if m.Version == "" { + errs = append(errs, errors.New("version for metadata should not be empty")) + } + if m.Resource == "" { + errs = append(errs, errors.New("resource name for metadata should not be empty")) + } + if m.Operation == "" { + errs = append(errs, errors.New("operation for metadata should not be empty")) + } + return errors.Join(errs...) +} + +func NewMetadata(version string, resource resources.Resource, operation Operation) Metadata { + return Metadata{ + Version: version, + Resource: resource.String(), + Operation: operation, + } +} + +func NewVersionedMetadata(resource resources.Resource, operation Operation) Metadata { + return Metadata{ + Version: ProviderVersion, + Resource: resource.String(), + Operation: operation, + } +} + +func NewContext(ctx context.Context, metadata Metadata) context.Context { + return context.WithValue(ctx, metadataContextKey, metadata) +} + +func FromContext(ctx context.Context) (Metadata, bool) { + metadata, ok := ctx.Value(metadataContextKey).(Metadata) + return metadata, ok +} diff --git a/pkg/internal/tracking/context_test.go b/pkg/internal/tracking/context_test.go new file mode 100644 index 0000000000..96e38f75a3 --- /dev/null +++ b/pkg/internal/tracking/context_test.go @@ -0,0 +1,45 @@ +package tracking + +import ( + "context" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/stretchr/testify/require" +) + +func Test_Context(t *testing.T) { + metadata := NewMetadata("123", resources.Account, CreateOperation) + newMetadata := NewMetadata("321", resources.Database, UpdateOperation) + ctx := context.Background() + + // no metadata in context + value := ctx.Value(metadataContextKey) + require.Nil(t, value) + + retrievedMetadata, ok := FromContext(ctx) + require.False(t, ok) + require.Empty(t, retrievedMetadata) + + // add metadata by hand + ctx = context.WithValue(ctx, metadataContextKey, metadata) + + value = ctx.Value(metadataContextKey) + require.NotNil(t, value) + require.Equal(t, metadata, value) + + retrievedMetadata, ok = FromContext(ctx) + require.True(t, ok) + require.Equal(t, metadata, retrievedMetadata) + + // add metadata with NewContext function (overrides previous value) + ctx = NewContext(ctx, newMetadata) + + value = ctx.Value(metadataContextKey) + require.NotNil(t, value) + require.Equal(t, newMetadata, value) + + retrievedMetadata, ok = FromContext(ctx) + require.True(t, ok) + require.Equal(t, newMetadata, retrievedMetadata) +} diff --git a/pkg/internal/tracking/query.go b/pkg/internal/tracking/query.go new file mode 100644 index 0000000000..e49421b1a9 --- /dev/null +++ b/pkg/internal/tracking/query.go @@ -0,0 +1,31 @@ +package tracking + +import ( + "encoding/json" + "fmt" + "strings" +) + +func AppendMetadata(sql string, metadata Metadata) (string, error) { + bytes, err := json.Marshal(metadata) + if err != nil { + return "", fmt.Errorf("failed to marshal the metadata: %w", err) + } else { + return fmt.Sprintf("%s --%s %s", sql, MetadataPrefix, string(bytes)), nil + } +} + +func ParseMetadata(sql string) (Metadata, error) { + parts := strings.Split(sql, fmt.Sprintf("--%s", MetadataPrefix)) + if len(parts) != 2 { + return Metadata{}, fmt.Errorf("failed to parse metadata from sql, incorrect number of parts, expected: 2, got: %d", len(parts)) + } + var metadata Metadata + if err := json.Unmarshal([]byte(strings.TrimSpace(parts[1])), &metadata); err != nil { + return Metadata{}, fmt.Errorf("failed to unmarshal metadata from sql: %s, err = %w", sql, err) + } + if err := metadata.validate(); err != nil { + return Metadata{}, err + } + return metadata, nil +} diff --git a/pkg/internal/tracking/query_test.go b/pkg/internal/tracking/query_test.go new file mode 100644 index 0000000000..6d46162186 --- /dev/null +++ b/pkg/internal/tracking/query_test.go @@ -0,0 +1,65 @@ +package tracking + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/stretchr/testify/require" +) + +func TestAppendMetadata(t *testing.T) { + metadata := NewMetadata("123", resources.Account, CreateOperation) + sql := "SELECT 1" + + bytes, err := json.Marshal(metadata) + require.NoError(t, err) + + expectedSql := fmt.Sprintf("%s --%s %s", sql, MetadataPrefix, string(bytes)) + + newSql, err := AppendMetadata(sql, metadata) + require.NoError(t, err) + require.Equal(t, expectedSql, newSql) +} + +func TestParseMetadata(t *testing.T) { + metadata := NewMetadata("123", resources.Account, CreateOperation) + bytes, err := json.Marshal(metadata) + require.NoError(t, err) + sql := fmt.Sprintf("SELECT 1 --%s %s", MetadataPrefix, string(bytes)) + + parsedMetadata, err := ParseMetadata(sql) + require.NoError(t, err) + require.Equal(t, metadata, parsedMetadata) +} + +func TestParseInvalidMetadataKeys(t *testing.T) { + sql := fmt.Sprintf(`SELECT 1 --%s {"key": "value"}`, MetadataPrefix) + + parsedMetadata, err := ParseMetadata(sql) + require.ErrorContains(t, err, "version for metadata should not be empty") + require.ErrorContains(t, err, "resource name for metadata should not be empty") + require.ErrorContains(t, err, "operation for metadata should not be empty") + require.Equal(t, Metadata{}, parsedMetadata) +} + +func TestParseInvalidMetadataJson(t *testing.T) { + sql := fmt.Sprintf(`SELECT 1 --%s "key": "value"`, MetadataPrefix) + + parsedMetadata, err := ParseMetadata(sql) + require.ErrorContains(t, err, "failed to unmarshal metadata from sql") + require.Equal(t, Metadata{}, parsedMetadata) +} + +func TestParseMetadataFromInvalidSqlCommentPrefix(t *testing.T) { + metadata := NewMetadata("123", resources.Account, CreateOperation) + sql := "SELECT 1" + + bytes, err := json.Marshal(metadata) + require.NoError(t, err) + + parsedMetadata, err := ParseMetadata(fmt.Sprintf("%s --invalid_prefix %s", sql, string(bytes))) + require.ErrorContains(t, err, "failed to parse metadata from sql") + require.Equal(t, Metadata{}, parsedMetadata) +} diff --git a/pkg/resources/common.go b/pkg/resources/common.go index 8a5df06f11..36a1da648a 100644 --- a/pkg/resources/common.go +++ b/pkg/resources/common.go @@ -5,6 +5,10 @@ import ( "regexp" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/go-cty/cty" @@ -101,3 +105,45 @@ func ImportName[T sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | s return []*schema.ResourceData{d}, nil } + +func TrackingImportWrapper(resourceName resources.Resource, importImplementation schema.StateContextFunc) schema.StateContextFunc { + return func(ctx context.Context, d *schema.ResourceData, meta any) ([]*schema.ResourceData, error) { + ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.ImportOperation)) + return importImplementation(ctx, d, meta) + } +} + +func TrackingCreateWrapper(resourceName resources.Resource, createImplementation schema.CreateContextFunc) schema.CreateContextFunc { + return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { + ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.CreateOperation)) + return createImplementation(ctx, d, meta) + } +} + +func TrackingReadWrapper(resourceName resources.Resource, readImplementation schema.ReadContextFunc) schema.ReadContextFunc { + return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { + ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.ReadOperation)) + return readImplementation(ctx, d, meta) + } +} + +func TrackingUpdateWrapper(resourceName resources.Resource, updateImplementation schema.UpdateContextFunc) schema.UpdateContextFunc { + return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { + ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.UpdateOperation)) + return updateImplementation(ctx, d, meta) + } +} + +func TrackingDeleteWrapper(resourceName resources.Resource, deleteImplementation schema.DeleteContextFunc) schema.DeleteContextFunc { + return func(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics { + ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.DeleteOperation)) + return deleteImplementation(ctx, d, meta) + } +} + +func TrackingCustomDiffWrapper(resourceName resources.Resource, customdiffImplementation schema.CustomizeDiffFunc) schema.CustomizeDiffFunc { + return func(ctx context.Context, diff *schema.ResourceDiff, meta any) error { + ctx = tracking.NewContext(ctx, tracking.NewVersionedMetadata(resourceName, tracking.CustomDiffOperation)) + return customdiffImplementation(ctx, diff, meta) + } +} diff --git a/pkg/resources/schema.go b/pkg/resources/schema.go index 27158b069d..e406eb31e2 100644 --- a/pkg/resources/schema.go +++ b/pkg/resources/schema.go @@ -8,6 +8,8 @@ import ( "slices" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" @@ -89,23 +91,24 @@ var schemaSchema = map[string]*schema.Schema{ // Schema returns a pointer to the resource representing a schema. func Schema() *schema.Resource { return &schema.Resource{ - CreateContext: CreateContextSchema, - ReadContext: ReadContextSchema(true), - UpdateContext: UpdateContextSchema, - DeleteContext: DeleteContextSchema, + CreateContext: TrackingCreateWrapper(resources.Schema, CreateContextSchema), + ReadContext: TrackingReadWrapper(resources.Schema, ReadContextSchema(true)), + UpdateContext: TrackingUpdateWrapper(resources.Schema, UpdateContextSchema), + DeleteContext: TrackingDeleteWrapper(resources.Schema, DeleteContextSchema), Description: "Resource used to manage schema objects. For more information, check [schema documentation](https://docs.snowflake.com/en/sql-reference/sql/create-schema).", - CustomizeDiff: customdiff.All( + CustomizeDiff: TrackingCustomDiffWrapper(resources.Schema, customdiff.All( ComputedIfAnyAttributeChanged(schemaSchema, ShowOutputAttributeName, "name", "comment", "with_managed_access", "is_transient"), ComputedIfAnyAttributeChanged(schemaSchema, DescribeOutputAttributeName, "name"), ComputedIfAnyAttributeChanged(schemaSchema, FullyQualifiedNameAttributeName, "name"), ComputedIfAnyAttributeChanged(schemaParametersSchema, ParametersAttributeName, collections.Map(sdk.AsStringList(sdk.AllSchemaParameters), strings.ToLower)...), + // TODO(SNOW-1804424 - next pr): handle custom context in parameters customdiff schemaParametersCustomDiff, - ), + )), Schema: collections.MergeMaps(schemaSchema, schemaParametersSchema), Importer: &schema.ResourceImporter{ - StateContext: ImportSchema, + StateContext: TrackingImportWrapper(resources.Schema, ImportSchema), }, SchemaVersion: 2, diff --git a/pkg/sdk/client.go b/pkg/sdk/client.go index 8f6d66c8a4..134313439d 100644 --- a/pkg/sdk/client.go +++ b/pkg/sdk/client.go @@ -8,6 +8,8 @@ import ( "os" "slices" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/snowflakeenvs" "github.com/jmoiron/sqlx" "github.com/luna-duclos/instrumentedsql" @@ -132,7 +134,7 @@ func NewClient(cfg *gosnowflake.Config) (*Client, error) { logger := instrumentedsql.LoggerFunc(func(ctx context.Context, s string, kv ...interface{}) { switch s { case "sql-conn-query", "sql-conn-exec": - log.Printf("[DEBUG] %s: %v (%s)\n", s, kv, ctx.Value(snowflakeAccountLocatorContextKey)) + log.Printf("[DEBUG] %s: %v (%s)\n", s, kv, ctx.Value(SnowflakeAccountLocatorContextKey)) default: return } @@ -264,11 +266,9 @@ func (c *Client) Close() error { return nil } -type snowflakeAccountLocatorContext string +type ContextKey string -const ( - snowflakeAccountLocatorContextKey snowflakeAccountLocatorContext = "snowflake_account_locator" -) +const SnowflakeAccountLocatorContextKey ContextKey = "snowflake_account_locator" // Exec executes a query that does not return rows. func (c *Client) exec(ctx context.Context, sql string) (sql.Result, error) { @@ -277,7 +277,8 @@ func (c *Client) exec(ctx context.Context, sql string) (sql.Result, error) { log.Printf("[DEBUG] sql-conn-exec-dry: %v\n", sql) return nil, nil } - ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) + ctx = context.WithValue(ctx, SnowflakeAccountLocatorContextKey, c.accountLocator) + sql = appendQueryMetadata(ctx, sql) result, err := c.db.ExecContext(ctx, sql) return result, decodeDriverError(err) } @@ -289,7 +290,8 @@ func (c *Client) query(ctx context.Context, dest interface{}, sql string) error log.Printf("[DEBUG] sql-conn-query-dry: %v\n", sql) return nil } - ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) + ctx = context.WithValue(ctx, SnowflakeAccountLocatorContextKey, c.accountLocator) + sql = appendQueryMetadata(ctx, sql) return decodeDriverError(c.db.SelectContext(ctx, dest, sql)) } @@ -300,6 +302,19 @@ func (c *Client) queryOne(ctx context.Context, dest interface{}, sql string) err log.Printf("[DEBUG] sql-conn-query-one-dry: %v\n", sql) return nil } - ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) + ctx = context.WithValue(ctx, SnowflakeAccountLocatorContextKey, c.accountLocator) + sql = appendQueryMetadata(ctx, sql) return decodeDriverError(c.db.GetContext(ctx, dest, sql)) } + +func appendQueryMetadata(ctx context.Context, sql string) string { + if metadata, ok := tracking.FromContext(ctx); ok { + newSql, err := tracking.AppendMetadata(sql, metadata) + if err != nil { + log.Printf("[ERROR] failed to append metadata tracking: %v\n", err) + return sql + } + return newSql + } + return sql +} diff --git a/pkg/sdk/context_functions.go b/pkg/sdk/context_functions.go index 1983189eb0..bbf39a23e3 100644 --- a/pkg/sdk/context_functions.go +++ b/pkg/sdk/context_functions.go @@ -20,6 +20,8 @@ type ContextFunctions interface { CurrentSession(ctx context.Context) (string, error) CurrentUser(ctx context.Context) (AccountObjectIdentifier, error) CurrentSessionDetails(ctx context.Context) (*CurrentSessionDetails, error) + + // TODO(SNOW-1805152): Remove this and utilize gosnowflake.WithQueryIDChan instead whenever query id is needed LastQueryId(ctx context.Context) (string, error) // Session Object functions. diff --git a/pkg/sdk/integration_test_imports.go b/pkg/sdk/integration_test_imports.go index aaf396b739..ab759dbce1 100644 --- a/pkg/sdk/integration_test_imports.go +++ b/pkg/sdk/integration_test_imports.go @@ -12,26 +12,22 @@ import ( // All the contents of this file were added to be able to use them outside the sdk package (i.e. integration tests package). // It was easier to do it that way, so that we do not include big rename changes in the first moving PR. -// ExecForTests is an exact copy of exec (that is unexported), that some integration tests/helpers were using +// ExecForTests is forwarding function for Client.exec (that is unexported), that some integration tests/helpers were using // TODO: remove after we have all usages covered by SDK (for now it means implementing stages, tables, and tags) func (c *Client) ExecForTests(ctx context.Context, sql string) (sql.Result, error) { - ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) - result, err := c.db.ExecContext(ctx, sql) - return result, decodeDriverError(err) + return c.exec(ctx, sql) } -// QueryOneForTests is an exact copy of queryOne (that is unexported), that some integration tests/helpers were using +// QueryOneForTests is forwarding function for Client.queryOne (that is unexported), that some integration tests/helpers were using // TODO: remove after introducing all resources using this func (c *Client) QueryOneForTests(ctx context.Context, dest interface{}, sql string) error { - ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) - return decodeDriverError(c.db.GetContext(ctx, dest, sql)) + return c.queryOne(ctx, dest, sql) } -// QueryForTests is an exact copy of query (that is unexported), that some integration tests/helpers were using +// QueryForTests is forwarding function for Client.query (that is unexported), that some integration tests/helpers were using // TODO: remove after introducing all resources using this func (c *Client) QueryForTests(ctx context.Context, dest interface{}, sql string) error { - ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) - return decodeDriverError(c.db.SelectContext(ctx, dest, sql)) + return c.query(ctx, dest, sql) } func ErrorsEqual(t *testing.T, expected error, actual error) { diff --git a/pkg/sdk/testint/basic_object_tracking_integration_test.go b/pkg/sdk/testint/basic_object_tracking_integration_test.go new file mode 100644 index 0000000000..673eb31df3 --- /dev/null +++ b/pkg/sdk/testint/basic_object_tracking_integration_test.go @@ -0,0 +1,113 @@ +package testint + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/snowflakedb/gosnowflake" + "github.com/stretchr/testify/require" +) + +// Research for basic object tracking done as part of SNOW-1737787 + +// https://docs.snowflake.com/en/sql-reference/parameters#query-tag +func TestInt_ContextQueryTags(t *testing.T) { + client := testClient(t) + ctx := context.Background() + + // set query_tag on user level + userQueryTag := "user query tag" + testClientHelper().User.AlterCurrentUser(t, &sdk.AlterUserOptions{ + Set: &sdk.UserSet{ + SessionParameters: &sdk.SessionParameters{ + QueryTag: sdk.String(userQueryTag), + }, + }, + }) + t.Cleanup(func() { + testClientHelper().User.AlterCurrentUser(t, &sdk.AlterUserOptions{ + Unset: &sdk.UserUnset{ + SessionParameters: &sdk.SessionParametersUnset{ + QueryTag: sdk.Bool(true), + }, + }, + }) + }) + queryId := executeQueryAndReturnQueryId(t, context.Background(), client) + queryTagResult := testClientHelper().InformationSchema.GetQueryTagByQueryId(t, queryId) + require.Equal(t, userQueryTag, queryTagResult) + + // set query_tag on session level + sessionQueryTag := "session query tag" + require.NoError(t, client.Sessions.AlterSession(ctx, &sdk.AlterSessionOptions{ + Set: &sdk.SessionSet{ + SessionParameters: &sdk.SessionParameters{ + QueryTag: sdk.String(sessionQueryTag), + }, + }, + })) + t.Cleanup(func() { + require.NoError(t, client.Sessions.AlterSession(ctx, &sdk.AlterSessionOptions{ + Unset: &sdk.SessionUnset{ + SessionParametersUnset: &sdk.SessionParametersUnset{ + QueryTag: sdk.Bool(true), + }, + }, + })) + }) + queryId = executeQueryAndReturnQueryId(t, context.Background(), client) + queryTagResult = testClientHelper().InformationSchema.GetQueryTagByQueryId(t, queryId) + require.Equal(t, sessionQueryTag, queryTagResult) + + // set query_tag on query level + perQueryQueryTag := "per-query query tag" + ctxWithQueryTag := gosnowflake.WithQueryTag(context.Background(), perQueryQueryTag) + queryId = executeQueryAndReturnQueryId(t, ctxWithQueryTag, client) + queryTagResult = testClientHelper().InformationSchema.GetQueryTagByQueryId(t, queryId) + require.Equal(t, perQueryQueryTag, queryTagResult) +} + +func executeQueryAndReturnQueryId(t *testing.T, ctx context.Context, client *sdk.Client) string { + t.Helper() + queryIdChan := make(chan string, 1) + ctx = gosnowflake.WithQueryIDChan(ctx, queryIdChan) + + _, err := client.QueryUnsafe(ctx, "SELECT 1") + require.NoError(t, err) + + return <-queryIdChan +} + +// https://select.dev/posts/snowflake-query-tags#using-query-comments-instead-of-query-tags +func TestInt_QueryComment(t *testing.T) { + client := testClient(t) + ctx := context.Background() + + queryIdChan := make(chan string, 1) + metadata := `{"comment": "some comment"}` + _, err := client.QueryUnsafe(gosnowflake.WithQueryIDChan(ctx, queryIdChan), fmt.Sprintf(`SELECT 1; --%s`, metadata)) + require.NoError(t, err) + queryId := <-queryIdChan + + queryText := testClientHelper().InformationSchema.GetQueryTextByQueryId(t, queryId) + require.Equal(t, metadata, strings.Split(queryText, "--")[1]) +} + +func TestInt_AppName(t *testing.T) { + // https://community.snowflake.com/s/article/How-to-see-application-name-added-in-the-connection-string-in-Snowsight + t.Skip("there no way to check client application name by querying Snowflake's") + + version := "v0.99.0" + config := sdk.DefaultConfig() + config.Application = fmt.Sprintf("terraform-provider-snowflake:%s", version) + client, err := sdk.NewClient(config) + require.NoError(t, err) + + _, err = client.QueryUnsafe(context.Background(), "SELECT 1") + require.NoError(t, err) +} + +// TODO(SNOW-1805150): Document potential usage of connection string diff --git a/pkg/sdk/testint/client_integration_test.go b/pkg/sdk/testint/client_integration_test.go new file mode 100644 index 0000000000..47a38e5449 --- /dev/null +++ b/pkg/sdk/testint/client_integration_test.go @@ -0,0 +1,63 @@ +package testint + +import ( + "context" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" + "github.com/snowflakedb/gosnowflake" + "github.com/stretchr/testify/require" +) + +func TestInt_Client_AdditionalMetadata(t *testing.T) { + client := testClient(t) + metadata := tracking.NewMetadata("v1.13.1002-rc-test", resources.Database, tracking.CreateOperation) + + assertQueryMetadata := func(t *testing.T, queryId string) { + t.Helper() + queryText := testClientHelper().InformationSchema.GetQueryTextByQueryId(t, queryId) + parsedMetadata, err := tracking.ParseMetadata(queryText) + require.NoError(t, err) + require.Equal(t, metadata, parsedMetadata) + } + + t.Run("query one", func(t *testing.T) { + queryIdChan := make(chan string, 1) + ctx := context.Background() + ctx = tracking.NewContext(ctx, metadata) + ctx = gosnowflake.WithQueryIDChan(ctx, queryIdChan) + row := struct { + One int `db:"ONE"` + }{} + err := client.QueryOneForTests(ctx, &row, "SELECT 1 AS ONE") + require.NoError(t, err) + + assertQueryMetadata(t, <-queryIdChan) + }) + + t.Run("query", func(t *testing.T) { + queryIdChan := make(chan string, 1) + ctx := context.Background() + ctx = tracking.NewContext(ctx, metadata) + ctx = gosnowflake.WithQueryIDChan(ctx, queryIdChan) + var rows []struct { + One int `db:"ONE"` + } + err := client.QueryForTests(ctx, &rows, "SELECT 1 AS ONE") + require.NoError(t, err) + + assertQueryMetadata(t, <-queryIdChan) + }) + + t.Run("exec", func(t *testing.T) { + queryIdChan := make(chan string, 1) + ctx := context.Background() + ctx = tracking.NewContext(ctx, metadata) + ctx = gosnowflake.WithQueryIDChan(ctx, queryIdChan) + _, err := client.ExecForTests(ctx, "SELECT 1") + require.NoError(t, err) + + assertQueryMetadata(t, <-queryIdChan) + }) +}