Skip to content

Commit

Permalink
feat: Use new data types in sql builder for functions and procedures (#…
Browse files Browse the repository at this point in the history
…3247)

Use new data types directly in Opts structs:
- rename old data type fields and make them optional for functions and
procedures
- add new data type fields for functions and procedures
- adjust the resources
- adjust unit, integration, and acceptance tests
- add integration tests verifying in detail the behavior of data types
for functions and procedures
- adjust and test SQL builder
- add Canonical method to new data types (useful because of how
Snowflake returns data types)

Next PRs:
- adjust function/procedure SDK
- handle Table return format as data type?
- resources/datasources (these will be a few PRs)
  • Loading branch information
sfc-gh-asawicki authored Dec 6, 2024
1 parent 5df33a8 commit 69f677a
Show file tree
Hide file tree
Showing 47 changed files with 2,678 additions and 448 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ test-acceptance: ## run acceptance tests
TF_ACC=1 SF_TF_ACC_TEST_CONFIGURE_CLIENT_ONCE=true TEST_SF_TF_REQUIRE_TEST_OBJECT_SUFFIX=1 go test -run "^TestAcc_" -v -cover -timeout=120m ./...

test-integration: ## run SDK integration tests
TEST_SF_TF_REQUIRE_TEST_OBJECT_SUFFIX=1 go test -run "^TestInt_" -v -cover -timeout=45m ./...
TEST_SF_TF_REQUIRE_TEST_OBJECT_SUFFIX=1 go test -run "^TestInt_" -v -cover -timeout=60m ./...

test-architecture: ## check architecture constraints between packages
go test ./pkg/architests/... -v
Expand Down
6 changes: 3 additions & 3 deletions pkg/acceptance/helpers/function_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (c *FunctionClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectI
return c.CreateWithRequest(t, id,
sdk.NewCreateForSQLFunctionRequest(
id.SchemaObjectId(),
*sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)),
*sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeInt)),
"SELECT 1",
),
)
Expand All @@ -48,7 +48,7 @@ func (c *FunctionClient) CreateSecure(t *testing.T, arguments ...sdk.DataType) *
return c.CreateWithRequest(t, id,
sdk.NewCreateForSQLFunctionRequest(
id.SchemaObjectId(),
*sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)),
*sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeInt)),
"SELECT 1",
).WithSecure(true),
)
Expand All @@ -59,7 +59,7 @@ func (c *FunctionClient) CreateWithRequest(t *testing.T, id sdk.SchemaObjectIden
ctx := context.Background()
argumentRequests := make([]sdk.FunctionArgumentRequest, len(id.ArgumentDataTypes()))
for i, argumentDataType := range id.ArgumentDataTypes() {
argumentRequests[i] = *sdk.NewFunctionArgumentRequest(c.ids.Alpha(), argumentDataType)
argumentRequests[i] = *sdk.NewFunctionArgumentRequest(c.ids.Alpha(), nil).WithArgDataTypeOld(argumentDataType)
}
err := c.client().CreateForSQL(ctx, req.WithArguments(argumentRequests))
require.NoError(t, err)
Expand Down
4 changes: 2 additions & 2 deletions pkg/acceptance/helpers/procedure_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ func (c *ProcedureClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObject
ctx := context.Background()
argumentRequests := make([]sdk.ProcedureArgumentRequest, len(id.ArgumentDataTypes()))
for i, argumentDataType := range id.ArgumentDataTypes() {
argumentRequests[i] = *sdk.NewProcedureArgumentRequest(c.ids.Alpha(), argumentDataType)
argumentRequests[i] = *sdk.NewProcedureArgumentRequest(c.ids.Alpha(), nil).WithArgDataTypeOld(argumentDataType)
}
err := c.client().CreateForSQL(ctx,
sdk.NewCreateForSQLProcedureRequest(
id.SchemaObjectId(),
*sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeInt)),
*sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeInt)),
`BEGIN RETURN 1; END`).WithArguments(argumentRequests),
)
require.NoError(t, err)
Expand Down
66 changes: 33 additions & 33 deletions pkg/resources/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,9 @@ func CreateContextFunction(ctx context.Context, d *schema.ResourceData, meta int
func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
client := meta.(*provider.Context).Client
name := d.Get("name").(string)
schema := d.Get("schema").(string)
sc := d.Get("schema").(string)
database := d.Get("database").(string)
id := sdk.NewSchemaObjectIdentifier(database, schema, name)
id := sdk.NewSchemaObjectIdentifier(database, sc, name)

// Set required
returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string))
Expand Down Expand Up @@ -266,14 +266,14 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf
request.WithComment(v.(string))
}
if _, ok := d.GetOk("imports"); ok {
imports := []sdk.FunctionImportRequest{}
var imports []sdk.FunctionImportRequest
for _, item := range d.Get("imports").([]interface{}) {
imports = append(imports, *sdk.NewFunctionImportRequest().WithImport(item.(string)))
}
request.WithImports(imports)
}
if _, ok := d.GetOk("packages"); ok {
packages := []sdk.FunctionPackageRequest{}
var packages []sdk.FunctionPackageRequest
for _, item := range d.Get("packages").([]interface{}) {
packages = append(packages, *sdk.NewFunctionPackageRequest().WithPackage(item.(string)))
}
Expand All @@ -288,19 +288,19 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf
}
argumentTypes := make([]sdk.DataType, 0, len(arguments))
for _, item := range arguments {
argumentTypes = append(argumentTypes, item.ArgDataType)
argumentTypes = append(argumentTypes, item.ArgDataTypeOld)
}
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...)
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...)
d.SetId(nid.FullyQualifiedName())
return ReadContextFunction(ctx, d, meta)
}

func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
client := meta.(*provider.Context).Client
name := d.Get("name").(string)
schema := d.Get("schema").(string)
sc := d.Get("schema").(string)
database := d.Get("database").(string)
id := sdk.NewSchemaObjectIdentifier(database, schema, name)
id := sdk.NewSchemaObjectIdentifier(database, sc, name)

// Set required
returnType := d.Get("return_type").(string)
Expand All @@ -311,7 +311,7 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter
functionDefinition := d.Get("statement").(string)
handler := d.Get("handler").(string)
// create request with required
request := sdk.NewCreateForScalaFunctionRequest(id, sdk.LegacyDataTypeFrom(returnDataType), handler)
request := sdk.NewCreateForScalaFunctionRequest(id, nil, handler).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType))
request.WithFunctionDefinition(functionDefinition)

// Set optionals
Expand All @@ -338,14 +338,14 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter
request.WithComment(v.(string))
}
if _, ok := d.GetOk("imports"); ok {
imports := []sdk.FunctionImportRequest{}
var imports []sdk.FunctionImportRequest
for _, item := range d.Get("imports").([]interface{}) {
imports = append(imports, *sdk.NewFunctionImportRequest().WithImport(item.(string)))
}
request.WithImports(imports)
}
if _, ok := d.GetOk("packages"); ok {
packages := []sdk.FunctionPackageRequest{}
var packages []sdk.FunctionPackageRequest
for _, item := range d.Get("packages").([]interface{}) {
packages = append(packages, *sdk.NewFunctionPackageRequest().WithPackage(item.(string)))
}
Expand All @@ -360,19 +360,19 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter
}
argumentTypes := make([]sdk.DataType, 0, len(arguments))
for _, item := range arguments {
argumentTypes = append(argumentTypes, item.ArgDataType)
argumentTypes = append(argumentTypes, item.ArgDataTypeOld)
}
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...)
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...)
d.SetId(nid.FullyQualifiedName())
return ReadContextFunction(ctx, d, meta)
}

func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
client := meta.(*provider.Context).Client
name := d.Get("name").(string)
schema := d.Get("schema").(string)
sc := d.Get("schema").(string)
database := d.Get("database").(string)
id := sdk.NewSchemaObjectIdentifier(database, schema, name)
id := sdk.NewSchemaObjectIdentifier(database, sc, name)

// Set required
returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string))
Expand Down Expand Up @@ -406,19 +406,19 @@ func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interfa
}
argumentTypes := make([]sdk.DataType, 0, len(arguments))
for _, item := range arguments {
argumentTypes = append(argumentTypes, item.ArgDataType)
argumentTypes = append(argumentTypes, item.ArgDataTypeOld)
}
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...)
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...)
d.SetId(nid.FullyQualifiedName())
return ReadContextFunction(ctx, d, meta)
}

func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
client := meta.(*provider.Context).Client
name := d.Get("name").(string)
schema := d.Get("schema").(string)
sc := d.Get("schema").(string)
database := d.Get("database").(string)
id := sdk.NewSchemaObjectIdentifier(database, schema, name)
id := sdk.NewSchemaObjectIdentifier(database, sc, name)

// Set required
returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string))
Expand Down Expand Up @@ -454,14 +454,14 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte
request.WithComment(v.(string))
}
if _, ok := d.GetOk("imports"); ok {
imports := []sdk.FunctionImportRequest{}
var imports []sdk.FunctionImportRequest
for _, item := range d.Get("imports").([]interface{}) {
imports = append(imports, *sdk.NewFunctionImportRequest().WithImport(item.(string)))
}
request.WithImports(imports)
}
if _, ok := d.GetOk("packages"); ok {
packages := []sdk.FunctionPackageRequest{}
var packages []sdk.FunctionPackageRequest
for _, item := range d.Get("packages").([]interface{}) {
packages = append(packages, *sdk.NewFunctionPackageRequest().WithPackage(item.(string)))
}
Expand All @@ -473,19 +473,19 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte
}
argumentTypes := make([]sdk.DataType, 0, len(arguments))
for _, item := range arguments {
argumentTypes = append(argumentTypes, item.ArgDataType)
argumentTypes = append(argumentTypes, item.ArgDataTypeOld)
}
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...)
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...)
d.SetId(nid.FullyQualifiedName())
return ReadContextFunction(ctx, d, meta)
}

func createJavascriptFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
client := meta.(*provider.Context).Client
name := d.Get("name").(string)
schema := d.Get("schema").(string)
sc := d.Get("schema").(string)
database := d.Get("database").(string)
id := sdk.NewSchemaObjectIdentifier(database, schema, name)
id := sdk.NewSchemaObjectIdentifier(database, sc, name)

// Set required
returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string))
Expand Down Expand Up @@ -522,9 +522,9 @@ func createJavascriptFunction(ctx context.Context, d *schema.ResourceData, meta
}
argumentTypes := make([]sdk.DataType, 0, len(arguments))
for _, item := range arguments {
argumentTypes = append(argumentTypes, item.ArgDataType)
argumentTypes = append(argumentTypes, item.ArgDataTypeOld)
}
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...)
nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...)
d.SetId(nid.FullyQualifiedName())
return ReadContextFunction(ctx, d, meta)
}
Expand Down Expand Up @@ -575,7 +575,7 @@ func ReadContextFunction(ctx context.Context, d *schema.ResourceData, meta inter
if value != "" { // Do nothing for functions without arguments
pairs := strings.Split(value, ", ")

arguments := []interface{}{}
var arguments []interface{}
for _, pair := range pairs {
item := strings.Split(pair, " ")
argument := map[string]interface{}{}
Expand Down Expand Up @@ -739,7 +739,7 @@ func parseFunctionArguments(d *schema.ResourceData) ([]sdk.FunctionArgumentReque
if diags != nil {
return nil, diags
}
args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)})
args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataTypeOld: sdk.LegacyDataTypeFrom(argDataType)})
}
}
return args, nil
Expand All @@ -764,8 +764,8 @@ func convertFunctionColumns(s string) ([]sdk.FunctionColumn, diag.Diagnostics) {
return nil, diag.FromErr(err)
}
columns = append(columns, sdk.FunctionColumn{
ColumnName: match[1],
ColumnDataType: sdk.LegacyDataTypeFrom(dataType),
ColumnName: match[1],
ColumnDataTypeOld: sdk.LegacyDataTypeFrom(dataType),
})
}
}
Expand All @@ -781,15 +781,15 @@ func parseFunctionReturnsRequest(s string) (*sdk.FunctionReturnsRequest, diag.Di
}
var cr []sdk.FunctionColumnRequest
for _, item := range columns {
cr = append(cr, *sdk.NewFunctionColumnRequest(item.ColumnName, item.ColumnDataType))
cr = append(cr, *sdk.NewFunctionColumnRequest(item.ColumnName, nil).WithColumnDataTypeOld(item.ColumnDataTypeOld))
}
returns.WithTable(*sdk.NewFunctionReturnsTableRequest().WithColumns(cr))
} else {
returnDataType, diags := convertFunctionDataType(s)
if diags != nil {
return nil, diags
}
returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType)))
returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType)))
}
return returns, nil
}
Loading

0 comments on commit 69f677a

Please sign in to comment.