From 69f677a6f86faa79cdece4d422eb61284c1599a6 Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Fri, 6 Dec 2024 11:03:29 +0100 Subject: [PATCH] feat: Use new data types in sql builder for functions and procedures (#3247) Use new data types directly in Opts structs: - rename old data type fields and make them optional for functions and procedures - add new data type fields for functions and procedures - adjust the resources - adjust unit, integration, and acceptance tests - add integration tests verifying in detail the behavior of data types for functions and procedures - adjust and test SQL builder - add Canonical method to new data types (useful because of how Snowflake returns data types) Next PRs: - adjust function/procedure SDK - handle Table return format as data type? - resources/datasources (these will be a few PRs) --- Makefile | 2 +- pkg/acceptance/helpers/function_client.go | 6 +- pkg/acceptance/helpers/procedure_client.go | 4 +- pkg/resources/function.go | 66 +- pkg/resources/procedure.go | 60 +- pkg/schemas/function_gen.go | 2 +- pkg/schemas/procedure_gen.go | 2 +- pkg/sdk/datatypes/array.go | 4 + pkg/sdk/datatypes/binary.go | 4 + pkg/sdk/datatypes/boolean.go | 4 + pkg/sdk/datatypes/data_types.go | 4 + pkg/sdk/datatypes/data_types_test.go | 23 +- pkg/sdk/datatypes/date.go | 4 + pkg/sdk/datatypes/float.go | 4 + pkg/sdk/datatypes/geography.go | 4 + pkg/sdk/datatypes/geometry.go | 4 + pkg/sdk/datatypes/number.go | 7 + pkg/sdk/datatypes/object.go | 4 + pkg/sdk/datatypes/text.go | 4 + pkg/sdk/datatypes/time.go | 4 + pkg/sdk/datatypes/timestamp_ltz.go | 4 + pkg/sdk/datatypes/timestamp_ntz.go | 4 + pkg/sdk/datatypes/timestamp_tz.go | 4 + pkg/sdk/datatypes/variant.go | 4 + pkg/sdk/datatypes/vector.go | 4 + pkg/sdk/functions_def.go | 23 +- pkg/sdk/functions_dto_builders_gen.go | 33 +- pkg/sdk/functions_dto_gen.go | 20 +- pkg/sdk/functions_ext.go | 5 + pkg/sdk/functions_gen.go | 28 +- pkg/sdk/functions_gen_test.go | 613 ++++++++- pkg/sdk/functions_impl_gen.go | 15 +- pkg/sdk/functions_validations_gen.go | 106 ++ pkg/sdk/poc/README.md | 1 + pkg/sdk/procedures_def.go | 35 +- pkg/sdk/procedures_dto_builders_gen.go | 40 +- pkg/sdk/procedures_dto_gen.go | 28 +- pkg/sdk/procedures_ext.go | 5 + pkg/sdk/procedures_gen.go | 36 +- pkg/sdk/procedures_gen_test.go | 1097 +++++++++++++++-- pkg/sdk/procedures_impl_gen.go | 56 +- pkg/sdk/procedures_validations_gen.go | 214 +++- pkg/sdk/random_test.go | 7 + pkg/sdk/sql_builder.go | 12 +- pkg/sdk/sql_builder_test.go | 125 ++ pkg/sdk/testint/functions_integration_test.go | 164 ++- .../testint/procedures_integration_test.go | 227 ++-- 47 files changed, 2678 insertions(+), 448 deletions(-) create mode 100644 pkg/sdk/functions_ext.go create mode 100644 pkg/sdk/procedures_ext.go diff --git a/Makefile b/Makefile index 1f914dc5d9..e8bd91a003 100644 --- a/Makefile +++ b/Makefile @@ -68,7 +68,7 @@ test-acceptance: ## run acceptance tests TF_ACC=1 SF_TF_ACC_TEST_CONFIGURE_CLIENT_ONCE=true TEST_SF_TF_REQUIRE_TEST_OBJECT_SUFFIX=1 go test -run "^TestAcc_" -v -cover -timeout=120m ./... test-integration: ## run SDK integration tests - TEST_SF_TF_REQUIRE_TEST_OBJECT_SUFFIX=1 go test -run "^TestInt_" -v -cover -timeout=45m ./... + TEST_SF_TF_REQUIRE_TEST_OBJECT_SUFFIX=1 go test -run "^TestInt_" -v -cover -timeout=60m ./... test-architecture: ## check architecture constraints between packages go test ./pkg/architests/... -v diff --git a/pkg/acceptance/helpers/function_client.go b/pkg/acceptance/helpers/function_client.go index 5b23afaf9f..3e6fe5a294 100644 --- a/pkg/acceptance/helpers/function_client.go +++ b/pkg/acceptance/helpers/function_client.go @@ -35,7 +35,7 @@ func (c *FunctionClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObjectI return c.CreateWithRequest(t, id, sdk.NewCreateForSQLFunctionRequest( id.SchemaObjectId(), - *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)), + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeInt)), "SELECT 1", ), ) @@ -48,7 +48,7 @@ func (c *FunctionClient) CreateSecure(t *testing.T, arguments ...sdk.DataType) * return c.CreateWithRequest(t, id, sdk.NewCreateForSQLFunctionRequest( id.SchemaObjectId(), - *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeInt)), + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeInt)), "SELECT 1", ).WithSecure(true), ) @@ -59,7 +59,7 @@ func (c *FunctionClient) CreateWithRequest(t *testing.T, id sdk.SchemaObjectIden ctx := context.Background() argumentRequests := make([]sdk.FunctionArgumentRequest, len(id.ArgumentDataTypes())) for i, argumentDataType := range id.ArgumentDataTypes() { - argumentRequests[i] = *sdk.NewFunctionArgumentRequest(c.ids.Alpha(), argumentDataType) + argumentRequests[i] = *sdk.NewFunctionArgumentRequest(c.ids.Alpha(), nil).WithArgDataTypeOld(argumentDataType) } err := c.client().CreateForSQL(ctx, req.WithArguments(argumentRequests)) require.NoError(t, err) diff --git a/pkg/acceptance/helpers/procedure_client.go b/pkg/acceptance/helpers/procedure_client.go index e9a4375f2d..34aec170f7 100644 --- a/pkg/acceptance/helpers/procedure_client.go +++ b/pkg/acceptance/helpers/procedure_client.go @@ -34,12 +34,12 @@ func (c *ProcedureClient) CreateWithIdentifier(t *testing.T, id sdk.SchemaObject ctx := context.Background() argumentRequests := make([]sdk.ProcedureArgumentRequest, len(id.ArgumentDataTypes())) for i, argumentDataType := range id.ArgumentDataTypes() { - argumentRequests[i] = *sdk.NewProcedureArgumentRequest(c.ids.Alpha(), argumentDataType) + argumentRequests[i] = *sdk.NewProcedureArgumentRequest(c.ids.Alpha(), nil).WithArgDataTypeOld(argumentDataType) } err := c.client().CreateForSQL(ctx, sdk.NewCreateForSQLProcedureRequest( id.SchemaObjectId(), - *sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeInt)), + *sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeInt)), `BEGIN RETURN 1; END`).WithArguments(argumentRequests), ) require.NoError(t, err) diff --git a/pkg/resources/function.go b/pkg/resources/function.go index 19415f91f1..ba91184217 100644 --- a/pkg/resources/function.go +++ b/pkg/resources/function.go @@ -227,9 +227,9 @@ func CreateContextFunction(ctx context.Context, d *schema.ResourceData, meta int func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string)) @@ -266,14 +266,14 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf request.WithComment(v.(string)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.FunctionImportRequest{} + var imports []sdk.FunctionImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewFunctionImportRequest().WithImport(item.(string))) } request.WithImports(imports) } if _, ok := d.GetOk("packages"); ok { - packages := []sdk.FunctionPackageRequest{} + var packages []sdk.FunctionPackageRequest for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewFunctionPackageRequest().WithPackage(item.(string))) } @@ -288,9 +288,9 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -298,9 +298,9 @@ func createJavaFunction(ctx context.Context, d *schema.ResourceData, meta interf func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returnType := d.Get("return_type").(string) @@ -311,7 +311,7 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter functionDefinition := d.Get("statement").(string) handler := d.Get("handler").(string) // create request with required - request := sdk.NewCreateForScalaFunctionRequest(id, sdk.LegacyDataTypeFrom(returnDataType), handler) + request := sdk.NewCreateForScalaFunctionRequest(id, nil, handler).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType)) request.WithFunctionDefinition(functionDefinition) // Set optionals @@ -338,14 +338,14 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter request.WithComment(v.(string)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.FunctionImportRequest{} + var imports []sdk.FunctionImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewFunctionImportRequest().WithImport(item.(string))) } request.WithImports(imports) } if _, ok := d.GetOk("packages"); ok { - packages := []sdk.FunctionPackageRequest{} + var packages []sdk.FunctionPackageRequest for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewFunctionPackageRequest().WithPackage(item.(string))) } @@ -360,9 +360,9 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -370,9 +370,9 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string)) @@ -406,9 +406,9 @@ func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interfa } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -416,9 +416,9 @@ func createSQLFunction(ctx context.Context, d *schema.ResourceData, meta interfa func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string)) @@ -454,14 +454,14 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte request.WithComment(v.(string)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.FunctionImportRequest{} + var imports []sdk.FunctionImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewFunctionImportRequest().WithImport(item.(string))) } request.WithImports(imports) } if _, ok := d.GetOk("packages"); ok { - packages := []sdk.FunctionPackageRequest{} + var packages []sdk.FunctionPackageRequest for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewFunctionPackageRequest().WithPackage(item.(string))) } @@ -473,9 +473,9 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -483,9 +483,9 @@ func createPythonFunction(ctx context.Context, d *schema.ResourceData, meta inte func createJavascriptFunction(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) - id := sdk.NewSchemaObjectIdentifier(database, schema, name) + id := sdk.NewSchemaObjectIdentifier(database, sc, name) // Set required returns, diags := parseFunctionReturnsRequest(d.Get("return_type").(string)) @@ -522,9 +522,9 @@ func createJavascriptFunction(ctx context.Context, d *schema.ResourceData, meta } argumentTypes := make([]sdk.DataType, 0, len(arguments)) for _, item := range arguments { - argumentTypes = append(argumentTypes, item.ArgDataType) + argumentTypes = append(argumentTypes, item.ArgDataTypeOld) } - nid := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argumentTypes...) + nid := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argumentTypes...) d.SetId(nid.FullyQualifiedName()) return ReadContextFunction(ctx, d, meta) } @@ -575,7 +575,7 @@ func ReadContextFunction(ctx context.Context, d *schema.ResourceData, meta inter if value != "" { // Do nothing for functions without arguments pairs := strings.Split(value, ", ") - arguments := []interface{}{} + var arguments []interface{} for _, pair := range pairs { item := strings.Split(pair, " ") argument := map[string]interface{}{} @@ -739,7 +739,7 @@ func parseFunctionArguments(d *schema.ResourceData) ([]sdk.FunctionArgumentReque if diags != nil { return nil, diags } - args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) + args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataTypeOld: sdk.LegacyDataTypeFrom(argDataType)}) } } return args, nil @@ -764,8 +764,8 @@ func convertFunctionColumns(s string) ([]sdk.FunctionColumn, diag.Diagnostics) { return nil, diag.FromErr(err) } columns = append(columns, sdk.FunctionColumn{ - ColumnName: match[1], - ColumnDataType: sdk.LegacyDataTypeFrom(dataType), + ColumnName: match[1], + ColumnDataTypeOld: sdk.LegacyDataTypeFrom(dataType), }) } } @@ -781,7 +781,7 @@ func parseFunctionReturnsRequest(s string) (*sdk.FunctionReturnsRequest, diag.Di } var cr []sdk.FunctionColumnRequest for _, item := range columns { - cr = append(cr, *sdk.NewFunctionColumnRequest(item.ColumnName, item.ColumnDataType)) + cr = append(cr, *sdk.NewFunctionColumnRequest(item.ColumnName, nil).WithColumnDataTypeOld(item.ColumnDataTypeOld)) } returns.WithTable(*sdk.NewFunctionReturnsTableRequest().WithColumns(cr)) } else { @@ -789,7 +789,7 @@ func parseFunctionReturnsRequest(s string) (*sdk.FunctionReturnsRequest, diag.Di if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) + returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index aa8b557250..f7577833f9 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -243,7 +243,7 @@ func CreateContextProcedure(ctx context.Context, d *schema.ResourceData, meta in func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -251,9 +251,9 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) if diags != nil { @@ -261,7 +261,7 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter } procedureDefinition := d.Get("statement").(string) runtimeVersion := d.Get("runtime_version").(string) - packages := []sdk.ProcedurePackageRequest{} + packages := make([]sdk.ProcedurePackageRequest, 0) for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) } @@ -287,7 +287,7 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter req.WithSecure(v.(bool)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.ProcedureImportRequest{} + var imports []sdk.ProcedureImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) } @@ -304,7 +304,7 @@ func createJavaProcedure(ctx context.Context, d *schema.ResourceData, meta inter func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -312,9 +312,9 @@ func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returnType := d.Get("return_type").(string) returnDataType, diags := convertProcedureDataType(returnType) @@ -322,7 +322,7 @@ func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta return diags } procedureDefinition := d.Get("statement").(string) - req := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.LegacyDataTypeFrom(returnDataType), procedureDefinition) + req := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, procedureDefinition).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType)) if len(args) > 0 { req.WithArguments(args) } @@ -355,7 +355,7 @@ func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -363,9 +363,9 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) if diags != nil { @@ -373,7 +373,7 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte } procedureDefinition := d.Get("statement").(string) runtimeVersion := d.Get("runtime_version").(string) - packages := []sdk.ProcedurePackageRequest{} + packages := make([]sdk.ProcedurePackageRequest, 0) for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) } @@ -399,7 +399,7 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte req.WithSecure(v.(bool)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.ProcedureImportRequest{} + var imports []sdk.ProcedureImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) } @@ -416,7 +416,7 @@ func createScalaProcedure(ctx context.Context, d *schema.ResourceData, meta inte func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -424,9 +424,9 @@ func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interf } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returns, diags := parseProcedureSQLReturnsRequest(d.Get("return_type").(string)) if diags != nil { @@ -466,7 +466,7 @@ func createSQLProcedure(ctx context.Context, d *schema.ResourceData, meta interf func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client := meta.(*provider.Context).Client name := d.Get("name").(string) - schema := d.Get("schema").(string) + sc := d.Get("schema").(string) database := d.Get("database").(string) args, diags := getProcedureArguments(d) if diags != nil { @@ -474,9 +474,9 @@ func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta int } argDataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - argDataTypes[i] = arg.ArgDataType + argDataTypes[i] = arg.ArgDataTypeOld } - id := sdk.NewSchemaObjectIdentifierWithArguments(database, schema, name, argDataTypes...) + id := sdk.NewSchemaObjectIdentifierWithArguments(database, sc, name, argDataTypes...) returns, diags := parseProcedureReturnsRequest(d.Get("return_type").(string)) if diags != nil { @@ -484,7 +484,7 @@ func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta int } procedureDefinition := d.Get("statement").(string) runtimeVersion := d.Get("runtime_version").(string) - packages := []sdk.ProcedurePackageRequest{} + packages := make([]sdk.ProcedurePackageRequest, 0) for _, item := range d.Get("packages").([]interface{}) { packages = append(packages, *sdk.NewProcedurePackageRequest(item.(string))) } @@ -518,7 +518,7 @@ func createPythonProcedure(ctx context.Context, d *schema.ResourceData, meta int req.WithSecure(v.(bool)) } if _, ok := d.GetOk("imports"); ok { - imports := []sdk.ProcedureImportRequest{} + var imports []sdk.ProcedureImportRequest for _, item := range d.Get("imports").([]interface{}) { imports = append(imports, *sdk.NewProcedureImportRequest(item.(string))) } @@ -577,7 +577,7 @@ func ReadContextProcedure(ctx context.Context, d *schema.ResourceData, meta inte if args != "" { // Do nothing for functions without arguments argPairs := strings.Split(args, ", ") - args := []interface{}{} + var args []any for _, argPair := range argPairs { argItem := strings.Split(argPair, " ") @@ -735,7 +735,7 @@ func getProcedureArguments(d *schema.ResourceData) ([]sdk.ProcedureArgumentReque if diags != nil { return nil, diags } - args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) + args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataTypeOld: sdk.LegacyDataTypeFrom(argDataType)}) } } return args, nil @@ -760,8 +760,8 @@ func convertProcedureColumns(s string) ([]sdk.ProcedureColumn, diag.Diagnostics) return nil, diag.FromErr(err) } columns = append(columns, sdk.ProcedureColumn{ - ColumnName: match[1], - ColumnDataType: sdk.LegacyDataTypeFrom(dataType), + ColumnName: match[1], + ColumnDataTypeOld: sdk.LegacyDataTypeFrom(dataType), }) } } @@ -777,7 +777,7 @@ func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag. } var cr []sdk.ProcedureColumnRequest for _, item := range columns { - cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) + cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, nil).WithColumnDataTypeOld(item.ColumnDataTypeOld)) } returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) } else { @@ -785,7 +785,7 @@ func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag. if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } @@ -799,7 +799,7 @@ func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, } var cr []sdk.ProcedureColumnRequest for _, item := range columns { - cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, item.ColumnDataType)) + cr = append(cr, *sdk.NewProcedureColumnRequest(item.ColumnName, nil).WithColumnDataTypeOld(item.ColumnDataTypeOld)) } returns.WithTable(*sdk.NewProcedureReturnsTableRequest().WithColumns(cr)) } else { @@ -807,7 +807,7 @@ func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } diff --git a/pkg/schemas/function_gen.go b/pkg/schemas/function_gen.go index ecd7d71f53..a211866ea5 100644 --- a/pkg/schemas/function_gen.go +++ b/pkg/schemas/function_gen.go @@ -95,7 +95,7 @@ func FunctionToSchema(function *sdk.Function) map[string]any { functionSchema["is_ansi"] = function.IsAnsi functionSchema["min_num_arguments"] = function.MinNumArguments functionSchema["max_num_arguments"] = function.MaxNumArguments - functionSchema["arguments"] = function.Arguments + functionSchema["arguments"] = function.ArgumentsOld functionSchema["arguments_raw"] = function.ArgumentsRaw functionSchema["description"] = function.Description functionSchema["catalog_name"] = function.CatalogName diff --git a/pkg/schemas/procedure_gen.go b/pkg/schemas/procedure_gen.go index 2499f043aa..38d5937273 100644 --- a/pkg/schemas/procedure_gen.go +++ b/pkg/schemas/procedure_gen.go @@ -83,7 +83,7 @@ func ProcedureToSchema(procedure *sdk.Procedure) map[string]any { procedureSchema["is_ansi"] = procedure.IsAnsi procedureSchema["min_num_arguments"] = procedure.MinNumArguments procedureSchema["max_num_arguments"] = procedure.MaxNumArguments - procedureSchema["arguments"] = procedure.Arguments + procedureSchema["arguments"] = procedure.ArgumentsOld procedureSchema["arguments_raw"] = procedure.ArgumentsRaw procedureSchema["description"] = procedure.Description procedureSchema["catalog_name"] = procedure.CatalogName diff --git a/pkg/sdk/datatypes/array.go b/pkg/sdk/datatypes/array.go index eb7247f6e6..835f48ae29 100644 --- a/pkg/sdk/datatypes/array.go +++ b/pkg/sdk/datatypes/array.go @@ -14,6 +14,10 @@ func (t *ArrayDataType) ToLegacyDataTypeSql() string { return ArrayLegacyDataType } +func (t *ArrayDataType) Canonical() string { + return ArrayLegacyDataType +} + var ArrayDataTypeSynonyms = []string{ArrayLegacyDataType} func parseArrayDataTypeRaw(raw sanitizedDataTypeRaw) (*ArrayDataType, error) { diff --git a/pkg/sdk/datatypes/binary.go b/pkg/sdk/datatypes/binary.go index c50dba0570..07181e3aaf 100644 --- a/pkg/sdk/datatypes/binary.go +++ b/pkg/sdk/datatypes/binary.go @@ -25,6 +25,10 @@ func (t *BinaryDataType) ToLegacyDataTypeSql() string { return BinaryLegacyDataType } +func (t *BinaryDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", BinaryLegacyDataType, t.size) +} + var BinaryDataTypeSynonyms = []string{BinaryLegacyDataType, "VARBINARY"} func parseBinaryDataTypeRaw(raw sanitizedDataTypeRaw) (*BinaryDataType, error) { diff --git a/pkg/sdk/datatypes/boolean.go b/pkg/sdk/datatypes/boolean.go index 4e84979f40..56f84a4064 100644 --- a/pkg/sdk/datatypes/boolean.go +++ b/pkg/sdk/datatypes/boolean.go @@ -14,6 +14,10 @@ func (t *BooleanDataType) ToLegacyDataTypeSql() string { return BooleanLegacyDataType } +func (t *BooleanDataType) Canonical() string { + return BooleanLegacyDataType +} + var BooleanDataTypeSynonyms = []string{BooleanLegacyDataType} func parseBooleanDataTypeRaw(raw sanitizedDataTypeRaw) (*BooleanDataType, error) { diff --git a/pkg/sdk/datatypes/data_types.go b/pkg/sdk/datatypes/data_types.go index e1c0065855..be58f978f2 100644 --- a/pkg/sdk/datatypes/data_types.go +++ b/pkg/sdk/datatypes/data_types.go @@ -14,8 +14,12 @@ import ( // DataType is the common interface that represents all Snowflake datatypes documented in https://docs.snowflake.com/en/sql-reference/intro-summary-data-types. type DataType interface { + // ToSql formats data type explicitly specifying all arguments and using the given type (e.g. CHAR(29) for CHAR(29)). ToSql() string + // ToLegacyDataTypeSql formats data type using its base type without any attributes (e.g. VARCHAR for CHAR(29)). ToLegacyDataTypeSql() string + // Canonical formats the data type between ToSql and ToLegacyDataTypeSql: it uses base type but with arguments (e.g. VARCHAR(29) for CHAR(29)). + Canonical() string } type sanitizedDataTypeRaw struct { diff --git a/pkg/sdk/datatypes/data_types_test.go b/pkg/sdk/datatypes/data_types_test.go index 21525fded8..cfb3845ef1 100644 --- a/pkg/sdk/datatypes/data_types_test.go +++ b/pkg/sdk/datatypes/data_types_test.go @@ -2,6 +2,7 @@ package datatypes import ( "fmt" + "slices" "strings" "testing" @@ -91,7 +92,12 @@ func Test_ParseDataType_Number(t *testing.T) { assert.Equal(t, tc.expectedUnderlyingType, parsed.(*NumberDataType).underlyingType) assert.Equal(t, NumberLegacyDataType, parsed.ToLegacyDataTypeSql()) - assert.Equal(t, fmt.Sprintf("%s(%d, %d)", parsed.(*NumberDataType).underlyingType, parsed.(*NumberDataType).precision, parsed.(*NumberDataType).scale), parsed.ToSql()) + if slices.Contains(NumberDataTypeSubTypes, parsed.(*NumberDataType).underlyingType) { + assert.Equal(t, parsed.(*NumberDataType).underlyingType, parsed.ToSql()) + } else { + assert.Equal(t, fmt.Sprintf("%s(%d, %d)", parsed.(*NumberDataType).underlyingType, parsed.(*NumberDataType).precision, parsed.(*NumberDataType).scale), parsed.ToSql()) + } + assert.Equal(t, fmt.Sprintf("%s(%d,%d)", NumberLegacyDataType, parsed.(*NumberDataType).precision, parsed.(*NumberDataType).scale), parsed.Canonical()) }) } @@ -158,6 +164,7 @@ func Test_ParseDataType_Float(t *testing.T) { assert.Equal(t, FloatLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, FloatLegacyDataType, parsed.Canonical()) }) } @@ -267,6 +274,7 @@ func Test_ParseDataType_Text(t *testing.T) { assert.Equal(t, VarcharLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TextDataType).underlyingType, parsed.(*TextDataType).length), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", VarcharLegacyDataType, parsed.(*TextDataType).length), parsed.Canonical()) }) } @@ -338,6 +346,7 @@ func Test_ParseDataType_Binary(t *testing.T) { assert.Equal(t, BinaryLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*BinaryDataType).underlyingType, parsed.(*BinaryDataType).size), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", BinaryLegacyDataType, parsed.(*BinaryDataType).size), parsed.Canonical()) }) } @@ -396,6 +405,7 @@ func Test_ParseDataType_Boolean(t *testing.T) { assert.Equal(t, BooleanLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, BooleanLegacyDataType, parsed.Canonical()) }) } @@ -452,6 +462,7 @@ func Test_ParseDataType_Date(t *testing.T) { assert.Equal(t, DateLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, DateLegacyDataType, parsed.Canonical()) }) } @@ -512,6 +523,7 @@ func Test_ParseDataType_Time(t *testing.T) { assert.Equal(t, TimeLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", tc.expectedUnderlyingType, tc.expectedPrecision), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", TimeLegacyDataType, tc.expectedPrecision), parsed.Canonical()) }) } @@ -581,6 +593,7 @@ func Test_ParseDataType_TimestampLtz(t *testing.T) { assert.Equal(t, TimestampLtzLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampLtzDataType).underlyingType, parsed.(*TimestampLtzDataType).precision), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", TimestampLtzLegacyDataType, parsed.(*TimestampLtzDataType).precision), parsed.Canonical()) }) } @@ -652,6 +665,7 @@ func Test_ParseDataType_TimestampNtz(t *testing.T) { assert.Equal(t, TimestampNtzLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampNtzDataType).underlyingType, parsed.(*TimestampNtzDataType).precision), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", TimestampNtzLegacyDataType, parsed.(*TimestampNtzDataType).precision), parsed.Canonical()) }) } @@ -721,6 +735,7 @@ func Test_ParseDataType_TimestampTz(t *testing.T) { assert.Equal(t, TimestampTzLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampTzDataType).underlyingType, parsed.(*TimestampTzDataType).precision), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", TimestampTzLegacyDataType, parsed.(*TimestampTzDataType).precision), parsed.Canonical()) }) } @@ -777,6 +792,7 @@ func Test_ParseDataType_Variant(t *testing.T) { assert.Equal(t, VariantLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, VariantLegacyDataType, parsed.Canonical()) }) } @@ -833,6 +849,7 @@ func Test_ParseDataType_Object(t *testing.T) { assert.Equal(t, ObjectLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, ObjectLegacyDataType, parsed.Canonical()) }) } @@ -889,6 +906,7 @@ func Test_ParseDataType_Array(t *testing.T) { assert.Equal(t, ArrayLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, ArrayLegacyDataType, parsed.Canonical()) }) } @@ -945,6 +963,7 @@ func Test_ParseDataType_Geography(t *testing.T) { assert.Equal(t, GeographyLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, GeographyLegacyDataType, parsed.Canonical()) }) } @@ -1001,6 +1020,7 @@ func Test_ParseDataType_Geometry(t *testing.T) { assert.Equal(t, GeometryLegacyDataType, parsed.ToLegacyDataTypeSql()) assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + assert.Equal(t, GeometryLegacyDataType, parsed.Canonical()) }) } @@ -1060,6 +1080,7 @@ func Test_ParseDataType_Vector(t *testing.T) { assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.ToLegacyDataTypeSql()) assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.ToSql()) + assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.Canonical()) }) } diff --git a/pkg/sdk/datatypes/date.go b/pkg/sdk/datatypes/date.go index 92ee7c27bc..c962a4a831 100644 --- a/pkg/sdk/datatypes/date.go +++ b/pkg/sdk/datatypes/date.go @@ -14,6 +14,10 @@ func (t *DateDataType) ToLegacyDataTypeSql() string { return DateLegacyDataType } +func (t *DateDataType) Canonical() string { + return DateLegacyDataType +} + var DateDataTypeSynonyms = []string{DateLegacyDataType} func parseDateDataTypeRaw(raw sanitizedDataTypeRaw) (*DateDataType, error) { diff --git a/pkg/sdk/datatypes/float.go b/pkg/sdk/datatypes/float.go index a0ca84863b..36fe0d9be0 100644 --- a/pkg/sdk/datatypes/float.go +++ b/pkg/sdk/datatypes/float.go @@ -14,6 +14,10 @@ func (t *FloatDataType) ToLegacyDataTypeSql() string { return FloatLegacyDataType } +func (t *FloatDataType) Canonical() string { + return FloatLegacyDataType +} + var FloatDataTypeSynonyms = []string{"FLOAT8", "FLOAT4", FloatLegacyDataType, "DOUBLE PRECISION", "DOUBLE", "REAL"} func parseFloatDataTypeRaw(raw sanitizedDataTypeRaw) (*FloatDataType, error) { diff --git a/pkg/sdk/datatypes/geography.go b/pkg/sdk/datatypes/geography.go index 4a024a20b0..43ee148212 100644 --- a/pkg/sdk/datatypes/geography.go +++ b/pkg/sdk/datatypes/geography.go @@ -14,6 +14,10 @@ func (t *GeographyDataType) ToLegacyDataTypeSql() string { return GeographyLegacyDataType } +func (t *GeographyDataType) Canonical() string { + return GeographyLegacyDataType +} + var GeographyDataTypeSynonyms = []string{GeographyLegacyDataType} func parseGeographyDataTypeRaw(raw sanitizedDataTypeRaw) (*GeographyDataType, error) { diff --git a/pkg/sdk/datatypes/geometry.go b/pkg/sdk/datatypes/geometry.go index d09ebd3eea..8ab62e817b 100644 --- a/pkg/sdk/datatypes/geometry.go +++ b/pkg/sdk/datatypes/geometry.go @@ -14,6 +14,10 @@ func (t *GeometryDataType) ToLegacyDataTypeSql() string { return GeometryLegacyDataType } +func (t *GeometryDataType) Canonical() string { + return GeometryLegacyDataType +} + var GeometryDataTypeSynonyms = []string{GeometryLegacyDataType} func parseGeometryDataTypeRaw(raw sanitizedDataTypeRaw) (*GeometryDataType, error) { diff --git a/pkg/sdk/datatypes/number.go b/pkg/sdk/datatypes/number.go index 14ac2696fc..cd11467717 100644 --- a/pkg/sdk/datatypes/number.go +++ b/pkg/sdk/datatypes/number.go @@ -24,6 +24,9 @@ type NumberDataType struct { } func (t *NumberDataType) ToSql() string { + if slices.Contains(NumberDataTypeSubTypes, t.underlyingType) { + return t.underlyingType + } return fmt.Sprintf("%s(%d, %d)", t.underlyingType, t.precision, t.scale) } @@ -31,6 +34,10 @@ func (t *NumberDataType) ToLegacyDataTypeSql() string { return NumberLegacyDataType } +func (t *NumberDataType) Canonical() string { + return fmt.Sprintf("%s(%d,%d)", NumberLegacyDataType, t.precision, t.scale) +} + var ( NumberDataTypeSynonyms = []string{NumberLegacyDataType, "DECIMAL", "DEC", "NUMERIC"} NumberDataTypeSubTypes = []string{"INTEGER", "INT", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"} diff --git a/pkg/sdk/datatypes/object.go b/pkg/sdk/datatypes/object.go index fe333aa7b0..098b04b0be 100644 --- a/pkg/sdk/datatypes/object.go +++ b/pkg/sdk/datatypes/object.go @@ -14,6 +14,10 @@ func (t *ObjectDataType) ToLegacyDataTypeSql() string { return ObjectLegacyDataType } +func (t *ObjectDataType) Canonical() string { + return ObjectLegacyDataType +} + var ObjectDataTypeSynonyms = []string{ObjectLegacyDataType} func parseObjectDataTypeRaw(raw sanitizedDataTypeRaw) (*ObjectDataType, error) { diff --git a/pkg/sdk/datatypes/text.go b/pkg/sdk/datatypes/text.go index 2598253101..c05d64f18c 100644 --- a/pkg/sdk/datatypes/text.go +++ b/pkg/sdk/datatypes/text.go @@ -30,6 +30,10 @@ func (t *TextDataType) ToLegacyDataTypeSql() string { return VarcharLegacyDataType } +func (t *TextDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", VarcharLegacyDataType, t.length) +} + var ( TextDataTypeSynonyms = []string{VarcharLegacyDataType, "STRING", "TEXT", "NVARCHAR2", "NVARCHAR", "CHAR VARYING", "NCHAR VARYING"} TextDataTypeSubtypes = []string{"CHARACTER", "CHAR", "NCHAR"} diff --git a/pkg/sdk/datatypes/time.go b/pkg/sdk/datatypes/time.go index ee79421122..e33223c104 100644 --- a/pkg/sdk/datatypes/time.go +++ b/pkg/sdk/datatypes/time.go @@ -25,6 +25,10 @@ func (t *TimeDataType) ToLegacyDataTypeSql() string { return TimeLegacyDataType } +func (t *TimeDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", TimeLegacyDataType, t.precision) +} + var TimeDataTypeSynonyms = []string{TimeLegacyDataType} func parseTimeDataTypeRaw(raw sanitizedDataTypeRaw) (*TimeDataType, error) { diff --git a/pkg/sdk/datatypes/timestamp_ltz.go b/pkg/sdk/datatypes/timestamp_ltz.go index f844ec537f..41961bfdb7 100644 --- a/pkg/sdk/datatypes/timestamp_ltz.go +++ b/pkg/sdk/datatypes/timestamp_ltz.go @@ -23,6 +23,10 @@ func (t *TimestampLtzDataType) ToLegacyDataTypeSql() string { return TimestampLtzLegacyDataType } +func (t *TimestampLtzDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", TimestampLtzLegacyDataType, t.precision) +} + var TimestampLtzDataTypeSynonyms = []string{TimestampLtzLegacyDataType, "TIMESTAMPLTZ", "TIMESTAMP WITH LOCAL TIME ZONE"} func parseTimestampLtzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampLtzDataType, error) { diff --git a/pkg/sdk/datatypes/timestamp_ntz.go b/pkg/sdk/datatypes/timestamp_ntz.go index 86aa5f0a0c..e11ed41b08 100644 --- a/pkg/sdk/datatypes/timestamp_ntz.go +++ b/pkg/sdk/datatypes/timestamp_ntz.go @@ -23,6 +23,10 @@ func (t *TimestampNtzDataType) ToLegacyDataTypeSql() string { return TimestampNtzLegacyDataType } +func (t *TimestampNtzDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", TimestampNtzLegacyDataType, t.precision) +} + var TimestampNtzDataTypeSynonyms = []string{TimestampNtzLegacyDataType, "TIMESTAMPNTZ", "TIMESTAMP WITHOUT TIME ZONE", "DATETIME"} func parseTimestampNtzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampNtzDataType, error) { diff --git a/pkg/sdk/datatypes/timestamp_tz.go b/pkg/sdk/datatypes/timestamp_tz.go index 44e6cafeb6..0c99944bf8 100644 --- a/pkg/sdk/datatypes/timestamp_tz.go +++ b/pkg/sdk/datatypes/timestamp_tz.go @@ -23,6 +23,10 @@ func (t *TimestampTzDataType) ToLegacyDataTypeSql() string { return TimestampTzLegacyDataType } +func (t *TimestampTzDataType) Canonical() string { + return fmt.Sprintf("%s(%d)", TimestampTzLegacyDataType, t.precision) +} + var TimestampTzDataTypeSynonyms = []string{TimestampTzLegacyDataType, "TIMESTAMPTZ", "TIMESTAMP WITH TIME ZONE"} func parseTimestampTzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampTzDataType, error) { diff --git a/pkg/sdk/datatypes/variant.go b/pkg/sdk/datatypes/variant.go index b096084934..538ca2921d 100644 --- a/pkg/sdk/datatypes/variant.go +++ b/pkg/sdk/datatypes/variant.go @@ -14,6 +14,10 @@ func (t *VariantDataType) ToLegacyDataTypeSql() string { return VariantLegacyDataType } +func (t *VariantDataType) Canonical() string { + return VariantLegacyDataType +} + var VariantDataTypeSynonyms = []string{VariantLegacyDataType} func parseVariantDataTypeRaw(raw sanitizedDataTypeRaw) (*VariantDataType, error) { diff --git a/pkg/sdk/datatypes/vector.go b/pkg/sdk/datatypes/vector.go index a535ca2b58..035249af64 100644 --- a/pkg/sdk/datatypes/vector.go +++ b/pkg/sdk/datatypes/vector.go @@ -26,6 +26,10 @@ func (t *VectorDataType) ToLegacyDataTypeSql() string { return t.ToSql() } +func (t *VectorDataType) Canonical() string { + return t.ToSql() +} + var ( VectorDataTypeSynonyms = []string{"VECTOR"} VectorAllowedInnerTypes = []string{"INT", "FLOAT"} diff --git a/pkg/sdk/functions_def.go b/pkg/sdk/functions_def.go index 017ccd8335..825c1d2551 100644 --- a/pkg/sdk/functions_def.go +++ b/pkg/sdk/functions_def.go @@ -6,18 +6,24 @@ import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/gen var functionArgument = g.NewQueryStruct("FunctionArgument"). Text("ArgName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ArgDataType", "DataType", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")) + PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ArgDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")). + WithValidation(g.ExactlyOneValueSet, "ArgDataTypeOld", "ArgDataType") var functionColumn = g.NewQueryStruct("FunctionColumn"). Text("ColumnName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ColumnDataType", "DataType", g.KeywordOptions().NoQuotes().Required()) + PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ColumnDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + WithValidation(g.ExactlyOneValueSet, "ColumnDataTypeOld", "ColumnDataType") var functionReturns = g.NewQueryStruct("FunctionReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("FunctionReturnsResultDataType"). - PredefinedQueryStructField("ResultDataType", "DataType", g.KeywordOptions().NoQuotes().Required()), + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), g.KeywordOptions(), ). OptionalQueryStructField( @@ -174,7 +180,9 @@ var FunctionsDef = g.NewInterface( functionArgument, g.ListOptions().MustParentheses()). OptionalSQL("COPY GRANTS"). - PredefinedQueryStructField("ResultDataType", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + SQL("RETURNS"). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). PredefinedQueryStructField("ReturnNullValues", "*ReturnNullValues", g.KeywordOptions()). SQL("LANGUAGE SCALA"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). @@ -196,7 +204,8 @@ var FunctionsDef = g.NewInterface( PredefinedQueryStructField("FunctionDefinition", "*string", g.ParameterOptions().NoEquals().SingleQuotes().SQL("AS")). WithValidation(g.ValidIdentifier, "name"). WithValidation(g.ValidateValueSet, "Handler"). - WithValidation(g.ConflictingFields, "OrReplace", "IfNotExists"), + WithValidation(g.ConflictingFields, "OrReplace", "IfNotExists"). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), ).CustomOperation( "CreateForSQL", "https://docs.snowflake.com/en/sql-reference/sql/create-function#sql-handler", @@ -282,7 +291,7 @@ var FunctionsDef = g.NewInterface( Field("IsAnsi", "bool"). Field("MinNumArguments", "int"). Field("MaxNumArguments", "int"). - Field("Arguments", "string"). + Field("ArgumentsRaw", "string"). Field("Description", "string"). Field("CatalogName", "string"). Field("IsTableFunction", "bool"). diff --git a/pkg/sdk/functions_dto_builders_gen.go b/pkg/sdk/functions_dto_builders_gen.go index 0aef014932..3bb40dfd0e 100644 --- a/pkg/sdk/functions_dto_builders_gen.go +++ b/pkg/sdk/functions_dto_builders_gen.go @@ -2,7 +2,10 @@ package sdk -import () +// imports edited manually +import ( + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" +) func NewCreateForJavaFunctionRequest( name SchemaObjectIdentifier, @@ -103,7 +106,7 @@ func (s *CreateForJavaFunctionRequest) WithFunctionDefinition(FunctionDefinition func NewFunctionArgumentRequest( ArgName string, - ArgDataType DataType, + ArgDataType datatypes.DataType, ) *FunctionArgumentRequest { s := FunctionArgumentRequest{} s.ArgName = ArgName @@ -111,6 +114,11 @@ func NewFunctionArgumentRequest( return &s } +func (s *FunctionArgumentRequest) WithArgDataTypeOld(ArgDataTypeOld DataType) *FunctionArgumentRequest { + s.ArgDataTypeOld = ArgDataTypeOld + return s +} + func (s *FunctionArgumentRequest) WithDefaultValue(DefaultValue string) *FunctionArgumentRequest { s.DefaultValue = &DefaultValue return s @@ -131,13 +139,18 @@ func (s *FunctionReturnsRequest) WithTable(Table FunctionReturnsTableRequest) *F } func NewFunctionReturnsResultDataTypeRequest( - ResultDataType DataType, + ResultDataType datatypes.DataType, ) *FunctionReturnsResultDataTypeRequest { s := FunctionReturnsResultDataTypeRequest{} s.ResultDataType = ResultDataType return &s } +func (s *FunctionReturnsResultDataTypeRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *FunctionReturnsResultDataTypeRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func NewFunctionReturnsTableRequest() *FunctionReturnsTableRequest { return &FunctionReturnsTableRequest{} } @@ -149,7 +162,7 @@ func (s *FunctionReturnsTableRequest) WithColumns(Columns []FunctionColumnReques func NewFunctionColumnRequest( ColumnName string, - ColumnDataType DataType, + ColumnDataType datatypes.DataType, ) *FunctionColumnRequest { s := FunctionColumnRequest{} s.ColumnName = ColumnName @@ -157,6 +170,11 @@ func NewFunctionColumnRequest( return &s } +func (s *FunctionColumnRequest) WithColumnDataTypeOld(ColumnDataTypeOld DataType) *FunctionColumnRequest { + s.ColumnDataTypeOld = ColumnDataTypeOld + return s +} + func NewFunctionImportRequest() *FunctionImportRequest { return &FunctionImportRequest{} } @@ -323,7 +341,7 @@ func (s *CreateForPythonFunctionRequest) WithFunctionDefinition(FunctionDefiniti func NewCreateForScalaFunctionRequest( name SchemaObjectIdentifier, - ResultDataType DataType, + ResultDataType datatypes.DataType, Handler string, ) *CreateForScalaFunctionRequest { s := CreateForScalaFunctionRequest{} @@ -363,6 +381,11 @@ func (s *CreateForScalaFunctionRequest) WithCopyGrants(CopyGrants bool) *CreateF return s } +func (s *CreateForScalaFunctionRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *CreateForScalaFunctionRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func (s *CreateForScalaFunctionRequest) WithReturnNullValues(ReturnNullValues ReturnNullValues) *CreateForScalaFunctionRequest { s.ReturnNullValues = &ReturnNullValues return s diff --git a/pkg/sdk/functions_dto_gen.go b/pkg/sdk/functions_dto_gen.go index 3fa0ff387b..4ff74dcd73 100644 --- a/pkg/sdk/functions_dto_gen.go +++ b/pkg/sdk/functions_dto_gen.go @@ -1,5 +1,7 @@ package sdk +import "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" + //go:generate go run ./dto-builder-generator/main.go var ( @@ -38,9 +40,10 @@ type CreateForJavaFunctionRequest struct { } type FunctionArgumentRequest struct { - ArgName string // required - ArgDataType DataType // required - DefaultValue *string + ArgName string // required + ArgDataTypeOld DataType + ArgDataType datatypes.DataType // required + DefaultValue *string } type FunctionReturnsRequest struct { @@ -49,7 +52,8 @@ type FunctionReturnsRequest struct { } type FunctionReturnsResultDataTypeRequest struct { - ResultDataType DataType // required + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required } type FunctionReturnsTableRequest struct { @@ -57,8 +61,9 @@ type FunctionReturnsTableRequest struct { } type FunctionColumnRequest struct { - ColumnName string // required - ColumnDataType DataType // required + ColumnName string // required + ColumnDataTypeOld DataType + ColumnDataType datatypes.DataType // required } type FunctionImportRequest struct { @@ -114,7 +119,8 @@ type CreateForScalaFunctionRequest struct { name SchemaObjectIdentifier // required Arguments []FunctionArgumentRequest CopyGrants *bool - ResultDataType DataType // required + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required ReturnNullValues *ReturnNullValues NullInputBehavior *NullInputBehavior ReturnResultsBehavior *ReturnResultsBehavior diff --git a/pkg/sdk/functions_ext.go b/pkg/sdk/functions_ext.go new file mode 100644 index 0000000000..4fe8a9524d --- /dev/null +++ b/pkg/sdk/functions_ext.go @@ -0,0 +1,5 @@ +package sdk + +func (v *Function) ID() SchemaObjectIdentifierWithArguments { + return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.ArgumentsOld...) +} diff --git a/pkg/sdk/functions_gen.go b/pkg/sdk/functions_gen.go index 85f2b8378d..ab7ca62170 100644 --- a/pkg/sdk/functions_gen.go +++ b/pkg/sdk/functions_gen.go @@ -1,8 +1,11 @@ package sdk +// imports edited manually import ( "context" "database/sql" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type Functions interface { @@ -46,9 +49,10 @@ type CreateForJavaFunctionOptions struct { } type FunctionArgument struct { - ArgName string `ddl:"keyword,no_quotes"` - ArgDataType DataType `ddl:"keyword,no_quotes"` - DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` + ArgName string `ddl:"keyword,no_quotes"` + ArgDataTypeOld DataType `ddl:"keyword,no_quotes"` + ArgDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` + DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` } type FunctionReturns struct { @@ -57,7 +61,8 @@ type FunctionReturns struct { } type FunctionReturnsResultDataType struct { - ResultDataType DataType `ddl:"keyword,no_quotes"` + ResultDataTypeOld DataType `ddl:"keyword,no_quotes"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` } type FunctionReturnsTable struct { @@ -65,8 +70,9 @@ type FunctionReturnsTable struct { } type FunctionColumn struct { - ColumnName string `ddl:"keyword,no_quotes"` - ColumnDataType DataType `ddl:"keyword,no_quotes"` + ColumnName string `ddl:"keyword,no_quotes"` + ColumnDataTypeOld DataType `ddl:"keyword,no_quotes"` + ColumnDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` } type FunctionImport struct { @@ -133,7 +139,9 @@ type CreateForScalaFunctionOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` Arguments []FunctionArgument `ddl:"list,must_parentheses"` CopyGrants *bool `ddl:"keyword" sql:"COPY GRANTS"` - ResultDataType DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + returns bool `ddl:"static" sql:"RETURNS"` + ResultDataTypeOld DataType `ddl:"parameter,no_equals"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` ReturnNullValues *ReturnNullValues `ddl:"keyword"` languageScala bool `ddl:"static" sql:"LANGUAGE SCALA"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` @@ -229,7 +237,7 @@ type Function struct { IsAnsi bool MinNumArguments int MaxNumArguments int - Arguments []DataType + ArgumentsOld []DataType ArgumentsRaw string Description string CatalogName string @@ -241,10 +249,6 @@ type Function struct { IsMemoizable bool } -func (v *Function) ID() SchemaObjectIdentifierWithArguments { - return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.Arguments...) -} - // DescribeFunctionOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-function. type DescribeFunctionOptions struct { describe bool `ddl:"static" sql:"DESCRIBE"` diff --git a/pkg/sdk/functions_gen_test.go b/pkg/sdk/functions_gen_test.go index b0c1c5b0b5..95c21d9204 100644 --- a/pkg/sdk/functions_gen_test.go +++ b/pkg/sdk/functions_gen_test.go @@ -24,12 +24,93 @@ func TestFunctions_CreateForJava(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + t.Run("validation: returns", func(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{} assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaFunctionOptions.Returns", "ResultDataType", "Table")) }) + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: function definition", func(t *testing.T) { opts := defaultOpts() opts.TargetPath = String("@~/testfunc.jar") @@ -47,12 +128,78 @@ func TestFunctions_CreateForJava(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaFunctionOptions", "Handler")) }) + // TODO [SNOW-1348103]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "id", + ArgDataTypeOld: DataTypeNumber, + }, + { + ArgName: "name", + ArgDataTypeOld: DataTypeVARCHAR, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + { + ColumnName: "country_code", + ColumnDataTypeOld: DataTypeVARCHAR, + }, + { + ColumnName: "country_name", + ColumnDataTypeOld: DataTypeVARCHAR, + }, + }, + }, + } + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.RuntimeVersion = String("2.0") + opts.Comment = String("comment") + opts.Imports = []FunctionImport{ + { + Import: "@~/my_decrement_udf_package_dir/my_decrement_udf_jar.jar", + }, + } + opts.Packages = []FunctionPackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Handler = "TestFunc.echoVarchar" + opts.ExternalAccessIntegrations = []AccountObjectIdentifier{ + NewAccountObjectIdentifier("ext_integration"), + } + opts.Secrets = []SecretReference{ + { + VariableName: "variable1", + Name: "name1", + }, + { + VariableName: "variable2", + Name: "name2", + }, + } + opts.TargetPath = String("@~/testfunc.jar") + opts.FunctionDefinition = String("return id + name;") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (id NUMBER, name VARCHAR DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR, country_name VARCHAR) NOT NULL LANGUAGE JAVA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@~/my_decrement_udf_package_dir/my_decrement_udf_jar.jar') PACKAGES = ('com.snowflake:snowpark:1.2.0') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' AS 'return id + name;'`, id.FullyQualifiedName()) + }) + t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -61,11 +208,11 @@ func TestFunctions_CreateForJava(t *testing.T) { opts.Arguments = []FunctionArgument{ { ArgName: "id", - ArgDataType: DataTypeNumber, + ArgDataType: dataTypeNumber, }, { ArgName: "name", - ArgDataType: DataTypeVARCHAR, + ArgDataType: dataTypeVarchar, DefaultValue: String("'test'"), }, } @@ -75,11 +222,11 @@ func TestFunctions_CreateForJava(t *testing.T) { Columns: []FunctionColumn{ { ColumnName: "country_code", - ColumnDataType: DataTypeVARCHAR, + ColumnDataType: dataTypeVarchar, }, { ColumnName: "country_name", - ColumnDataType: DataTypeVARCHAR, + ColumnDataType: dataTypeVarchar, }, }, }, @@ -115,7 +262,7 @@ func TestFunctions_CreateForJava(t *testing.T) { } opts.TargetPath = String("@~/testfunc.jar") opts.FunctionDefinition = String("return id + name;") - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (id NUMBER, name VARCHAR DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR, country_name VARCHAR) NOT NULL LANGUAGE JAVA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@~/my_decrement_udf_package_dir/my_decrement_udf_jar.jar') PACKAGES = ('com.snowflake:snowpark:1.2.0') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' AS 'return id + name;'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (id NUMBER(36, 2), name VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR(100), country_name VARCHAR(100)) NOT NULL LANGUAGE JAVA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@~/my_decrement_udf_package_dir/my_decrement_udf_jar.jar') PACKAGES = ('com.snowflake:snowpark:1.2.0') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' AS 'return id + name;'`, id.FullyQualifiedName()) }) } @@ -139,6 +286,87 @@ func TestFunctions_CreateForJavascript(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavascriptFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavascriptFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavascriptFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: returns", func(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{} @@ -149,12 +377,39 @@ func TestFunctions_CreateForJavascript(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavascriptFunctionOptions", "FunctionDefinition")) }) + // TODO [SNOW-1348103]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "d", + ArgDataTypeOld: DataTypeFloat, + DefaultValue: String("1.0"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + }, + } + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.Comment = String("comment") + opts.FunctionDefinition = "return 1;" + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (d FLOAT DEFAULT 1.0) COPY GRANTS RETURNS FLOAT NOT NULL LANGUAGE JAVASCRIPT CALLED ON NULL INPUT IMMUTABLE COMMENT = 'comment' AS 'return 1;'`, id.FullyQualifiedName()) + }) + t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -163,14 +418,14 @@ func TestFunctions_CreateForJavascript(t *testing.T) { opts.Arguments = []FunctionArgument{ { ArgName: "d", - ArgDataType: DataTypeFloat, + ArgDataType: dataTypeFloat, DefaultValue: String("1.0"), }, } opts.CopyGrants = Bool(true) opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) @@ -202,6 +457,87 @@ func TestFunctions_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one correct, one incorrect", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: returns", func(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{} @@ -212,7 +548,7 @@ func TestFunctions_CreateForPython(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonFunctionOptions", "RuntimeVersion")) @@ -229,6 +565,64 @@ func TestFunctions_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, NewError("IMPORTS must not be empty when AS is nil")) }) + // TODO [SNOW-1348103]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "i", + ArgDataTypeOld: DataTypeNumber, + DefaultValue: String("1"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeVariant, + }, + } + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.RuntimeVersion = "3.8" + opts.Comment = String("comment") + opts.Imports = []FunctionImport{ + { + Import: "numpy", + }, + { + Import: "pandas", + }, + } + opts.Packages = []FunctionPackage{ + { + Package: "numpy", + }, + { + Package: "pandas", + }, + } + opts.Handler = "udf" + opts.ExternalAccessIntegrations = []AccountObjectIdentifier{ + NewAccountObjectIdentifier("ext_integration"), + } + opts.Secrets = []SecretReference{ + { + VariableName: "variable1", + Name: "name1", + }, + { + VariableName: "variable2", + Name: "name2", + }, + } + opts.FunctionDefinition = String("import numpy as np") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (i NUMBER DEFAULT 1) COPY GRANTS RETURNS VARIANT NOT NULL LANGUAGE PYTHON CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '3.8' COMMENT = 'comment' IMPORTS = ('numpy', 'pandas') PACKAGES = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) AS 'import numpy as np'`, id.FullyQualifiedName()) + }) + t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -237,14 +631,14 @@ func TestFunctions_CreateForPython(t *testing.T) { opts.Arguments = []FunctionArgument{ { ArgName: "i", - ArgDataType: DataTypeNumber, + ArgDataType: dataTypeNumber, DefaultValue: String("1"), }, } opts.CopyGrants = Bool(true) opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVariant, + ResultDataType: dataTypeVariant, }, } opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) @@ -283,7 +677,7 @@ func TestFunctions_CreateForPython(t *testing.T) { }, } opts.FunctionDefinition = String("import numpy as np") - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (i NUMBER DEFAULT 1) COPY GRANTS RETURNS VARIANT NOT NULL LANGUAGE PYTHON CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '3.8' COMMENT = 'comment' IMPORTS = ('numpy', 'pandas') PACKAGES = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) AS 'import numpy as np'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (i NUMBER(36, 2) DEFAULT 1) COPY GRANTS RETURNS VARIANT NOT NULL LANGUAGE PYTHON CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '3.8' COMMENT = 'comment' IMPORTS = ('numpy', 'pandas') PACKAGES = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) AS 'import numpy as np'`, id.FullyQualifiedName()) }) } @@ -307,6 +701,43 @@ func TestFunctions_CreateForScala(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.ResultDataTypeOld opts.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.ResultDataTypeOld opts.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.ResultDataTypeOld = DataTypeFloat + opts.ResultDataType = dataTypeFloat + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaFunctionOptions", "ResultDataTypeOld", "ResultDataType")) + }) + t.Run("validation: function definition", func(t *testing.T) { opts := defaultOpts() opts.TargetPath = String("@~/testfunc.jar") @@ -322,10 +753,40 @@ func TestFunctions_CreateForScala(t *testing.T) { t.Run("validation: options are missing", func(t *testing.T) { opts := defaultOpts() - opts.ResultDataType = DataTypeVARCHAR + opts.ResultDataType = dataTypeVarchar assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaFunctionOptions", "Handler")) }) + // TODO [SNOW-1348103]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "x", + ArgDataTypeOld: DataTypeVARCHAR, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.ResultDataTypeOld = DataTypeVARCHAR + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.RuntimeVersion = String("2.0") + opts.Comment = String("comment") + opts.Imports = []FunctionImport{ + { + Import: "@udf_libs/echohandler.jar", + }, + } + opts.Handler = "Echo.echoVarchar" + opts.FunctionDefinition = String("return x") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (x VARCHAR DEFAULT 'test') COPY GRANTS RETURNS VARCHAR NOT NULL LANGUAGE SCALA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' AS 'return x'`, id.FullyQualifiedName()) + }) + t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -334,12 +795,12 @@ func TestFunctions_CreateForScala(t *testing.T) { opts.Arguments = []FunctionArgument{ { ArgName: "x", - ArgDataType: DataTypeVARCHAR, + ArgDataType: dataTypeVarchar, DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) - opts.ResultDataType = DataTypeVARCHAR + opts.ResultDataType = dataTypeVarchar opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) @@ -352,7 +813,7 @@ func TestFunctions_CreateForScala(t *testing.T) { } opts.Handler = "Echo.echoVarchar" opts.FunctionDefinition = String("return x") - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (x VARCHAR DEFAULT 'test') COPY GRANTS RETURNS VARCHAR NOT NULL LANGUAGE SCALA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' AS 'return x'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (x VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS VARCHAR(100) NOT NULL LANGUAGE SCALA CALLED ON NULL INPUT IMMUTABLE RUNTIME_VERSION = '2.0' COMMENT = 'comment' IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' AS 'return x'`, id.FullyQualifiedName()) }) } @@ -376,6 +837,87 @@ func TestFunctions_CreateForSQL(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []FunctionArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = FunctionReturns{ + Table: &FunctionReturnsTable{ + Columns: []FunctionColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + t.Run("validation: returns", func(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{} @@ -386,7 +928,7 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForSQLFunctionOptions", "FunctionDefinition")) @@ -396,13 +938,40 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts := defaultOpts() opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } opts.FunctionDefinition = "3.141592654::FLOAT" assertOptsValidAndSQLEquals(t, opts, `CREATE FUNCTION %s () RETURNS FLOAT AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) + // TODO [SNOW-1348103]: remove with old function removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Temporary = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []FunctionArgument{ + { + ArgName: "message", + ArgDataTypeOld: "VARCHAR", + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = FunctionReturns{ + ResultDataType: &FunctionReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + }, + } + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.Memoizable = Bool(true) + opts.Comment = String("comment") + opts.FunctionDefinition = "3.141592654::FLOAT" + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (message VARCHAR DEFAULT 'test') COPY GRANTS RETURNS FLOAT NOT NULL IMMUTABLE MEMOIZABLE COMMENT = 'comment' AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) + }) + t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) @@ -411,14 +980,14 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts.Arguments = []FunctionArgument{ { ArgName: "message", - ArgDataType: "VARCHAR", + ArgDataType: dataTypeVarchar, DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) opts.Returns = FunctionReturns{ ResultDataType: &FunctionReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) @@ -426,7 +995,7 @@ func TestFunctions_CreateForSQL(t *testing.T) { opts.Memoizable = Bool(true) opts.Comment = String("comment") opts.FunctionDefinition = "3.141592654::FLOAT" - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (message VARCHAR DEFAULT 'test') COPY GRANTS RETURNS FLOAT NOT NULL IMMUTABLE MEMOIZABLE COMMENT = 'comment' AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE TEMPORARY SECURE FUNCTION %s (message VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS FLOAT NOT NULL IMMUTABLE MEMOIZABLE COMMENT = 'comment' AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) } diff --git a/pkg/sdk/functions_impl_gen.go b/pkg/sdk/functions_impl_gen.go index 2abf41c1e6..ca17781139 100644 --- a/pkg/sdk/functions_impl_gen.go +++ b/pkg/sdk/functions_impl_gen.go @@ -110,7 +110,8 @@ func (r *CreateForJavaFunctionRequest) toOpts() *CreateForJavaFunctionOptions { opts.Returns = FunctionReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { @@ -165,7 +166,8 @@ func (r *CreateForJavascriptFunctionRequest) toOpts() *CreateForJavascriptFuncti opts.Returns = FunctionReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { @@ -212,7 +214,8 @@ func (r *CreateForPythonFunctionRequest) toOpts() *CreateForPythonFunctionOption opts.Returns = FunctionReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { @@ -251,6 +254,7 @@ func (r *CreateForScalaFunctionRequest) toOpts() *CreateForScalaFunctionOptions name: r.name, CopyGrants: r.CopyGrants, + ResultDataTypeOld: r.ResultDataTypeOld, ResultDataType: r.ResultDataType, ReturnNullValues: r.ReturnNullValues, NullInputBehavior: r.NullInputBehavior, @@ -311,7 +315,8 @@ func (r *CreateForSQLFunctionRequest) toOpts() *CreateForSQLFunctionOptions { opts.Returns = FunctionReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &FunctionReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { @@ -386,7 +391,7 @@ func (r functionRow) convert() *Function { if err != nil { log.Printf("[DEBUG] failed to parse function arguments, err = %s", err) } else { - e.Arguments = dataTypes + e.ArgumentsOld = dataTypes } if r.IsSecure.Valid { diff --git a/pkg/sdk/functions_validations_gen.go b/pkg/sdk/functions_validations_gen.go index 3bf1a29ff9..78970158e8 100644 --- a/pkg/sdk/functions_validations_gen.go +++ b/pkg/sdk/functions_validations_gen.go @@ -26,11 +26,35 @@ func (opts *CreateForJavaFunctionOptions) validate() error { if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateForJavaFunctionOptions", "OrReplace", "IfNotExists")) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForJavaFunctionOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } + // added manually if opts.FunctionDefinition == nil { if opts.TargetPath != nil { errs = append(errs, NewError("TARGET_PATH must be nil when AS is nil")) @@ -56,10 +80,33 @@ func (opts *CreateForJavascriptFunctionOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavascriptFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForJavascriptFunctionOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -81,11 +128,35 @@ func (opts *CreateForPythonFunctionOptions) validate() error { if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateForPythonFunctionOptions", "OrReplace", "IfNotExists")) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForPythonFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForPythonFunctionOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } + // added manually if opts.FunctionDefinition == nil { if len(opts.Imports) == 0 { errs = append(errs, NewError("IMPORTS must not be empty when AS is nil")) @@ -108,6 +179,18 @@ func (opts *CreateForScalaFunctionOptions) validate() error { if everyValueSet(opts.OrReplace, opts.IfNotExists) { errs = append(errs, errOneOf("CreateForScalaFunctionOptions", "OrReplace", "IfNotExists")) } + if !exactlyOneValueSet(opts.ResultDataTypeOld, opts.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForScalaFunctionOptions", "ResultDataTypeOld", "ResultDataType")) + } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForScalaFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } + // added manually if opts.FunctionDefinition == nil { if opts.TargetPath != nil { errs = append(errs, NewError("TARGET_PATH must be nil when AS is nil")) @@ -133,10 +216,33 @@ func (opts *CreateForSQLFunctionOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLFunctionOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } diff --git a/pkg/sdk/poc/README.md b/pkg/sdk/poc/README.md index 44af1e130b..46cb6e16b9 100644 --- a/pkg/sdk/poc/README.md +++ b/pkg/sdk/poc/README.md @@ -109,6 +109,7 @@ find a better solution to solve the issue (add more logic to the templates ?) - there should be no need to define custom types every time - more clear definition of lists that can be empty vs cannot be empty - add empty ids in generated tests (TODO in random_test.go) +- add optional imports (currently they have to be added manually, e.g. `datatypes.DataType`) ##### Known issues - generating two converts when Show and Desc use the same data structure diff --git a/pkg/sdk/procedures_def.go b/pkg/sdk/procedures_def.go index 3b5eb69882..0485b7711b 100644 --- a/pkg/sdk/procedures_def.go +++ b/pkg/sdk/procedures_def.go @@ -6,19 +6,25 @@ import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/gen var procedureArgument = g.NewQueryStruct("ProcedureArgument"). Text("ArgName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ArgDataType", "DataType", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")) + PredefinedQueryStructField("ArgDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ArgDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + PredefinedQueryStructField("DefaultValue", "*string", g.ParameterOptions().NoEquals().SQL("DEFAULT")). + WithValidation(g.ExactlyOneValueSet, "ArgDataTypeOld", "ArgDataType") var procedureColumn = g.NewQueryStruct("ProcedureColumn"). Text("ColumnName", g.KeywordOptions().NoQuotes().Required()). - PredefinedQueryStructField("ColumnDataType", "DataType", g.KeywordOptions().NoQuotes().Required()) + PredefinedQueryStructField("ColumnDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ColumnDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + WithValidation(g.ExactlyOneValueSet, "ColumnDataTypeOld", "ColumnDataType") var procedureReturns = g.NewQueryStruct("ProcedureReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("ProcedureReturnsResultDataType"). - PredefinedQueryStructField("ResultDataType", "DataType", g.KeywordOptions().NoQuotes().Required()). - OptionalSQL("NULL").OptionalSQL("NOT NULL"), + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + OptionalSQL("NULL").OptionalSQL("NOT NULL"). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), g.KeywordOptions(), ). OptionalQueryStructField( @@ -36,7 +42,9 @@ var procedureSQLReturns = g.NewQueryStruct("ProcedureSQLReturns"). OptionalQueryStructField( "ResultDataType", g.NewQueryStruct("ProcedureReturnsResultDataType"). - PredefinedQueryStructField("ResultDataType", "DataType", g.KeywordOptions().NoQuotes().Required()), + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.KeywordOptions().NoQuotes()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), g.KeywordOptions(), ). OptionalQueryStructField( @@ -126,7 +134,9 @@ var ProceduresDef = g.NewInterface( g.ListOptions().MustParentheses(), ). OptionalSQL("COPY GRANTS"). - PredefinedQueryStructField("ResultDataType", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + SQL("RETURNS"). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). OptionalSQL("NOT NULL"). SQL("LANGUAGE JAVASCRIPT"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). @@ -134,7 +144,8 @@ var ProceduresDef = g.NewInterface( PredefinedQueryStructField("ExecuteAs", "*ExecuteAs", g.KeywordOptions()). PredefinedQueryStructField("ProcedureDefinition", "string", g.ParameterOptions().NoEquals().SingleQuotes().SQL("AS").Required()). WithValidation(g.ValidateValueSet, "ProcedureDefinition"). - WithValidation(g.ValidIdentifier, "name"), + WithValidation(g.ValidIdentifier, "name"). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"), ).CustomOperation( "CreateForPython", "https://docs.snowflake.com/en/sql-reference/sql/create-procedure#python-handler", @@ -299,7 +310,7 @@ var ProceduresDef = g.NewInterface( Field("IsAnsi", "bool"). Field("MinNumArguments", "int"). Field("MaxNumArguments", "int"). - Field("Arguments", "string"). + Field("ArgumentsRaw", "string"). Field("Description", "string"). Field("CatalogName", "string"). Field("IsTableFunction", "bool"). @@ -437,7 +448,9 @@ var ProceduresDef = g.NewInterface( procedureArgument, g.ListOptions().MustParentheses(), ). - PredefinedQueryStructField("ResultDataType", "DataType", g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + SQL("RETURNS"). + PredefinedQueryStructField("ResultDataTypeOld", "DataType", g.ParameterOptions().NoEquals()). + PredefinedQueryStructField("ResultDataType", "datatypes.DataType", g.ParameterOptions().NoQuotes().NoEquals().Required()). OptionalSQL("NOT NULL"). SQL("LANGUAGE JAVASCRIPT"). PredefinedQueryStructField("NullInputBehavior", "*NullInputBehavior", g.KeywordOptions()). @@ -452,7 +465,7 @@ var ProceduresDef = g.NewInterface( PredefinedQueryStructField("CallArguments", "[]string", g.KeywordOptions().MustParentheses()). PredefinedQueryStructField("ScriptingVariable", "*string", g.ParameterOptions().NoEquals().NoQuotes().SQL("INTO")). WithValidation(g.ValidateValueSet, "ProcedureDefinition"). - WithValidation(g.ValidateValueSet, "ResultDataType"). + WithValidation(g.ExactlyOneValueSet, "ResultDataTypeOld", "ResultDataType"). WithValidation(g.ValidIdentifier, "ProcedureName"). WithValidation(g.ValidIdentifier, "Name"), ).CustomOperation( diff --git a/pkg/sdk/procedures_dto_builders_gen.go b/pkg/sdk/procedures_dto_builders_gen.go index c88c8bd9ac..373852a62a 100644 --- a/pkg/sdk/procedures_dto_builders_gen.go +++ b/pkg/sdk/procedures_dto_builders_gen.go @@ -2,7 +2,10 @@ package sdk -import () +// imports added manually +import ( + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" +) func NewCreateForJavaProcedureRequest( name SchemaObjectIdentifier, @@ -82,7 +85,7 @@ func (s *CreateForJavaProcedureRequest) WithProcedureDefinition(ProcedureDefinit func NewProcedureArgumentRequest( ArgName string, - ArgDataType DataType, + ArgDataType datatypes.DataType, ) *ProcedureArgumentRequest { s := ProcedureArgumentRequest{} s.ArgName = ArgName @@ -90,6 +93,11 @@ func NewProcedureArgumentRequest( return &s } +func (s *ProcedureArgumentRequest) WithArgDataTypeOld(ArgDataTypeOld DataType) *ProcedureArgumentRequest { + s.ArgDataTypeOld = ArgDataTypeOld + return s +} + func (s *ProcedureArgumentRequest) WithDefaultValue(DefaultValue string) *ProcedureArgumentRequest { s.DefaultValue = &DefaultValue return s @@ -110,13 +118,18 @@ func (s *ProcedureReturnsRequest) WithTable(Table ProcedureReturnsTableRequest) } func NewProcedureReturnsResultDataTypeRequest( - ResultDataType DataType, + ResultDataType datatypes.DataType, ) *ProcedureReturnsResultDataTypeRequest { s := ProcedureReturnsResultDataTypeRequest{} s.ResultDataType = ResultDataType return &s } +func (s *ProcedureReturnsResultDataTypeRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *ProcedureReturnsResultDataTypeRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func (s *ProcedureReturnsResultDataTypeRequest) WithNull(Null bool) *ProcedureReturnsResultDataTypeRequest { s.Null = &Null return s @@ -138,7 +151,7 @@ func (s *ProcedureReturnsTableRequest) WithColumns(Columns []ProcedureColumnRequ func NewProcedureColumnRequest( ColumnName string, - ColumnDataType DataType, + ColumnDataType datatypes.DataType, ) *ProcedureColumnRequest { s := ProcedureColumnRequest{} s.ColumnName = ColumnName @@ -146,6 +159,11 @@ func NewProcedureColumnRequest( return &s } +func (s *ProcedureColumnRequest) WithColumnDataTypeOld(ColumnDataTypeOld DataType) *ProcedureColumnRequest { + s.ColumnDataTypeOld = ColumnDataTypeOld + return s +} + func NewProcedurePackageRequest( Package string, ) *ProcedurePackageRequest { @@ -164,7 +182,7 @@ func NewProcedureImportRequest( func NewCreateForJavaScriptProcedureRequest( name SchemaObjectIdentifier, - ResultDataType DataType, + ResultDataType datatypes.DataType, ProcedureDefinition string, ) *CreateForJavaScriptProcedureRequest { s := CreateForJavaScriptProcedureRequest{} @@ -194,6 +212,11 @@ func (s *CreateForJavaScriptProcedureRequest) WithCopyGrants(CopyGrants bool) *C return s } +func (s *CreateForJavaScriptProcedureRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *CreateForJavaScriptProcedureRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func (s *CreateForJavaScriptProcedureRequest) WithNotNull(NotNull bool) *CreateForJavaScriptProcedureRequest { s.NotNull = &NotNull return s @@ -646,7 +669,7 @@ func (s *CreateAndCallForScalaProcedureRequest) WithScriptingVariable(ScriptingV func NewCreateAndCallForJavaScriptProcedureRequest( Name AccountObjectIdentifier, - ResultDataType DataType, + ResultDataType datatypes.DataType, ProcedureDefinition string, ProcedureName AccountObjectIdentifier, ) *CreateAndCallForJavaScriptProcedureRequest { @@ -663,6 +686,11 @@ func (s *CreateAndCallForJavaScriptProcedureRequest) WithArguments(Arguments []P return s } +func (s *CreateAndCallForJavaScriptProcedureRequest) WithResultDataTypeOld(ResultDataTypeOld DataType) *CreateAndCallForJavaScriptProcedureRequest { + s.ResultDataTypeOld = ResultDataTypeOld + return s +} + func (s *CreateAndCallForJavaScriptProcedureRequest) WithNotNull(NotNull bool) *CreateAndCallForJavaScriptProcedureRequest { s.NotNull = &NotNull return s diff --git a/pkg/sdk/procedures_dto_gen.go b/pkg/sdk/procedures_dto_gen.go index 8ad24b86e6..bf3e0a8d72 100644 --- a/pkg/sdk/procedures_dto_gen.go +++ b/pkg/sdk/procedures_dto_gen.go @@ -1,5 +1,8 @@ package sdk +// imports added manually +import "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" + //go:generate go run ./dto-builder-generator/main.go var ( @@ -41,9 +44,10 @@ type CreateForJavaProcedureRequest struct { } type ProcedureArgumentRequest struct { - ArgName string // required - ArgDataType DataType // required - DefaultValue *string + ArgName string // required + ArgDataTypeOld DataType + ArgDataType datatypes.DataType // required + DefaultValue *string } type ProcedureReturnsRequest struct { @@ -52,9 +56,10 @@ type ProcedureReturnsRequest struct { } type ProcedureReturnsResultDataTypeRequest struct { - ResultDataType DataType // required - Null *bool - NotNull *bool + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required + Null *bool + NotNull *bool } type ProcedureReturnsTableRequest struct { @@ -62,8 +67,9 @@ type ProcedureReturnsTableRequest struct { } type ProcedureColumnRequest struct { - ColumnName string // required - ColumnDataType DataType // required + ColumnName string // required + ColumnDataTypeOld DataType + ColumnDataType datatypes.DataType // required } type ProcedurePackageRequest struct { @@ -80,7 +86,8 @@ type CreateForJavaScriptProcedureRequest struct { name SchemaObjectIdentifier // required Arguments []ProcedureArgumentRequest CopyGrants *bool - ResultDataType DataType // required + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required NotNull *bool NullInputBehavior *NullInputBehavior Comment *string @@ -218,7 +225,8 @@ type CreateAndCallForScalaProcedureRequest struct { type CreateAndCallForJavaScriptProcedureRequest struct { Name AccountObjectIdentifier // required Arguments []ProcedureArgumentRequest - ResultDataType DataType // required + ResultDataTypeOld DataType + ResultDataType datatypes.DataType // required NotNull *bool NullInputBehavior *NullInputBehavior ProcedureDefinition string // required diff --git a/pkg/sdk/procedures_ext.go b/pkg/sdk/procedures_ext.go new file mode 100644 index 0000000000..31307bc2fb --- /dev/null +++ b/pkg/sdk/procedures_ext.go @@ -0,0 +1,5 @@ +package sdk + +func (v *Procedure) ID() SchemaObjectIdentifierWithArguments { + return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.ArgumentsOld...) +} diff --git a/pkg/sdk/procedures_gen.go b/pkg/sdk/procedures_gen.go index e265558f70..c65e95e94a 100644 --- a/pkg/sdk/procedures_gen.go +++ b/pkg/sdk/procedures_gen.go @@ -3,6 +3,9 @@ package sdk import ( "context" "database/sql" + + // import added manually + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type Procedures interface { @@ -49,9 +52,10 @@ type CreateForJavaProcedureOptions struct { } type ProcedureArgument struct { - ArgName string `ddl:"keyword,no_quotes"` - ArgDataType DataType `ddl:"keyword,no_quotes"` - DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` + ArgName string `ddl:"keyword,no_quotes"` + ArgDataTypeOld DataType `ddl:"keyword,no_quotes"` + ArgDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` + DefaultValue *string `ddl:"parameter,no_equals" sql:"DEFAULT"` } type ProcedureReturns struct { @@ -60,9 +64,10 @@ type ProcedureReturns struct { } type ProcedureReturnsResultDataType struct { - ResultDataType DataType `ddl:"keyword,no_quotes"` - Null *bool `ddl:"keyword" sql:"NULL"` - NotNull *bool `ddl:"keyword" sql:"NOT NULL"` + ResultDataTypeOld DataType `ddl:"keyword,no_quotes"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` + Null *bool `ddl:"keyword" sql:"NULL"` + NotNull *bool `ddl:"keyword" sql:"NOT NULL"` } type ProcedureReturnsTable struct { @@ -70,8 +75,9 @@ type ProcedureReturnsTable struct { } type ProcedureColumn struct { - ColumnName string `ddl:"keyword,no_quotes"` - ColumnDataType DataType `ddl:"keyword,no_quotes"` + ColumnName string `ddl:"keyword,no_quotes"` + ColumnDataTypeOld DataType `ddl:"keyword,no_quotes"` + ColumnDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` } type ProcedurePackage struct { @@ -91,7 +97,9 @@ type CreateForJavaScriptProcedureOptions struct { name SchemaObjectIdentifier `ddl:"identifier"` Arguments []ProcedureArgument `ddl:"list,must_parentheses"` CopyGrants *bool `ddl:"keyword" sql:"COPY GRANTS"` - ResultDataType DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + returns bool `ddl:"static" sql:"RETURNS"` + ResultDataTypeOld DataType `ddl:"parameter,no_equals"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` NotNull *bool `ddl:"keyword" sql:"NOT NULL"` languageJavascript bool `ddl:"static" sql:"LANGUAGE JAVASCRIPT"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` @@ -226,7 +234,7 @@ type Procedure struct { IsAnsi bool MinNumArguments int MaxNumArguments int - Arguments []DataType + ArgumentsOld []DataType ArgumentsRaw string Description string CatalogName string @@ -235,10 +243,6 @@ type Procedure struct { IsSecure bool } -func (v *Procedure) ID() SchemaObjectIdentifierWithArguments { - return NewSchemaObjectIdentifierWithArguments(v.CatalogName, v.SchemaName, v.Name, v.Arguments...) -} - // DescribeProcedureOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-procedure. type DescribeProcedureOptions struct { describe bool `ddl:"static" sql:"DESCRIBE"` @@ -318,7 +322,9 @@ type CreateAndCallForJavaScriptProcedureOptions struct { Name AccountObjectIdentifier `ddl:"identifier"` asProcedure bool `ddl:"static" sql:"AS PROCEDURE"` Arguments []ProcedureArgument `ddl:"list,must_parentheses"` - ResultDataType DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + returns bool `ddl:"static" sql:"RETURNS"` + ResultDataTypeOld DataType `ddl:"parameter,no_equals"` + ResultDataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` NotNull *bool `ddl:"keyword" sql:"NOT NULL"` languageJavascript bool `ddl:"static" sql:"LANGUAGE JAVASCRIPT"` NullInputBehavior *NullInputBehavior `ddl:"keyword"` diff --git a/pkg/sdk/procedures_gen_test.go b/pkg/sdk/procedures_gen_test.go index 7717308d51..994181b59b 100644 --- a/pkg/sdk/procedures_gen_test.go +++ b/pkg/sdk/procedures_gen_test.go @@ -9,7 +9,14 @@ func TestProcedures_CreateForJava(t *testing.T) { defaultOpts := func() *CreateForJavaProcedureOptions { return &CreateForJavaProcedureOptions{ - name: id, + name: id, + Handler: "TestFunc.echoVarchar", + Packages: []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + }, + RuntimeVersion: "1.8", } } @@ -24,7 +31,106 @@ func TestProcedures_CreateForJava(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: returns", func(t *testing.T) { + t.Run("validation: [opts.RuntimeVersion] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.RuntimeVersion = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "RuntimeVersion")) + }) + + t.Run("validation: [opts.Packages] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Packages = nil + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "Packages")) + }) + + t.Run("validation: [opts.Handler] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Handler = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "Handler")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{} assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaProcedureOptions.Returns", "ResultDataType", "Table")) @@ -41,15 +147,64 @@ func TestProcedures_CreateForJava(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, NewError("TARGET_PATH must be nil when AS is nil")) }) - t.Run("validation: options are missing", func(t *testing.T) { + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "id", + ArgDataTypeOld: DataTypeNumber, + }, + { + ArgName: "name", + ArgDataTypeOld: DataTypeVARCHAR, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) opts.Returns = ProcedureReturns{ - ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + { + ColumnName: "country_code", + ColumnDataTypeOld: DataTypeVARCHAR, + }, + }, }, } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "Handler")) - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaProcedureOptions", "RuntimeVersion")) + opts.RuntimeVersion = "1.8" + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "test_jar.jar", + }, + } + opts.Handler = "TestFunc.echoVarchar" + opts.ExternalAccessIntegrations = []AccountObjectIdentifier{ + NewAccountObjectIdentifier("ext_integration"), + } + opts.Secrets = []SecretReference{ + { + VariableName: "variable1", + Name: "name1", + }, + { + VariableName: "variable2", + Name: "name2", + }, + } + opts.TargetPath = String("@~/testfunc.jar") + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = String("return id + name;") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (id NUMBER, name VARCHAR DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return id + name;'`, id.FullyQualifiedName()) }) t.Run("all options", func(t *testing.T) { @@ -59,11 +214,11 @@ func TestProcedures_CreateForJava(t *testing.T) { opts.Arguments = []ProcedureArgument{ { ArgName: "id", - ArgDataType: DataTypeNumber, + ArgDataType: dataTypeNumber, }, { ArgName: "name", - ArgDataType: DataTypeVARCHAR, + ArgDataType: dataTypeVarchar, DefaultValue: String("'test'"), }, } @@ -73,7 +228,7 @@ func TestProcedures_CreateForJava(t *testing.T) { Columns: []ProcedureColumn{ { ColumnName: "country_code", - ColumnDataType: DataTypeVARCHAR, + ColumnDataType: dataTypeVarchar, }, }, }, @@ -108,7 +263,7 @@ func TestProcedures_CreateForJava(t *testing.T) { opts.Comment = String("test comment") opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) opts.ProcedureDefinition = String("return id + name;") - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (id NUMBER, name VARCHAR DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return id + name;'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (id NUMBER(36, 2), name VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS TABLE (country_code VARCHAR(100)) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return id + name;'`, id.FullyQualifiedName()) }) } @@ -117,7 +272,8 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { defaultOpts := func() *CreateForJavaScriptProcedureOptions { return &CreateForJavaScriptProcedureOptions{ - name: id, + name: id, + ProcedureDefinition: "return 1;", } } @@ -126,15 +282,75 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) + t.Run("validation: [opts.ProcedureDefinition] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.ProcedureDefinition = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaScriptProcedureOptions", "ProcedureDefinition")) + }) + + t.Run("validation: exactly one field from [opts.ResultDataTypeOld opts.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.ResultDataTypeOld opts.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.ResultDataTypeOld = DataTypeFloat + opts.ResultDataType = dataTypeFloat + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one correct, one incorrect", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = emptySchemaObjectIdentifier assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: options are missing", func(t *testing.T) { + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForJavaScriptProcedureOptions", "ProcedureDefinition")) + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "d", + ArgDataTypeOld: "DOUBLE", + DefaultValue: String("1.0"), + }, + } + opts.CopyGrants = Bool(true) + opts.ResultDataTypeOld = "DOUBLE" + opts.NotNull = Bool(true) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = "return 1;" + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (d DOUBLE DEFAULT 1.0) COPY GRANTS RETURNS DOUBLE NOT NULL LANGUAGE JAVASCRIPT STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return 1;'`, id.FullyQualifiedName()) }) t.Run("all options", func(t *testing.T) { @@ -144,18 +360,18 @@ func TestProcedures_CreateForJavaScript(t *testing.T) { opts.Arguments = []ProcedureArgument{ { ArgName: "d", - ArgDataType: "DOUBLE", + ArgDataType: dataTypeFloat, DefaultValue: String("1.0"), }, } opts.CopyGrants = Bool(true) - opts.ResultDataType = "DOUBLE" + opts.ResultDataType = dataTypeFloat opts.NotNull = Bool(true) opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) opts.Comment = String("test comment") opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) opts.ProcedureDefinition = "return 1;" - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (d DOUBLE DEFAULT 1.0) COPY GRANTS RETURNS DOUBLE NOT NULL LANGUAGE JAVASCRIPT STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return 1;'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (d FLOAT DEFAULT 1.0) COPY GRANTS RETURNS FLOAT NOT NULL LANGUAGE JAVASCRIPT STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return 1;'`, id.FullyQualifiedName()) }) } @@ -164,7 +380,14 @@ func TestProcedures_CreateForPython(t *testing.T) { defaultOpts := func() *CreateForPythonProcedureOptions { return &CreateForPythonProcedureOptions{ - name: id, + name: id, + RuntimeVersion: "3.8", + Packages: []ProcedurePackage{ + { + Package: "numpy", + }, + }, + Handler: "udf", } } @@ -173,27 +396,172 @@ func TestProcedures_CreateForPython(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) + t.Run("validation: [opts.RuntimeVersion] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.RuntimeVersion = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "RuntimeVersion")) + }) + + t.Run("validation: [opts.Packages] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Packages = nil + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "Packages")) + }) + + t.Run("validation: [opts.Handler] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Handler = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "Handler")) + }) + t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = emptySchemaObjectIdentifier assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: returns", func(t *testing.T) { + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{} assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForPythonProcedureOptions.Returns", "ResultDataType", "Table")) }) - t.Run("validation: options are missing", func(t *testing.T) { + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "i", + ArgDataTypeOld: "int", + DefaultValue: String("1"), + }, + } + opts.CopyGrants = Bool(true) opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataTypeOld: "VARIANT", + Null: Bool(true), }, } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "Handler")) - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForPythonProcedureOptions", "RuntimeVersion")) + opts.RuntimeVersion = "3.8" + opts.Packages = []ProcedurePackage{ + { + Package: "numpy", + }, + { + Package: "pandas", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "numpy", + }, + { + Import: "pandas", + }, + } + opts.Handler = "udf" + opts.ExternalAccessIntegrations = []AccountObjectIdentifier{ + NewAccountObjectIdentifier("ext_integration"), + } + opts.Secrets = []SecretReference{ + { + VariableName: "variable1", + Name: "name1", + }, + { + VariableName: "variable2", + Name: "name2", + }, + } + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = String("import numpy as np") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (i int DEFAULT 1) COPY GRANTS RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'import numpy as np'`, id.FullyQualifiedName()) }) t.Run("all options", func(t *testing.T) { @@ -203,14 +571,14 @@ func TestProcedures_CreateForPython(t *testing.T) { opts.Arguments = []ProcedureArgument{ { ArgName: "i", - ArgDataType: "int", + ArgDataType: dataTypeNumber, DefaultValue: String("1"), }, } opts.CopyGrants = Bool(true) opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: "VARIANT", + ResultDataType: dataTypeVariant, Null: Bool(true), }, } @@ -249,7 +617,7 @@ func TestProcedures_CreateForPython(t *testing.T) { opts.Comment = String("test comment") opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) opts.ProcedureDefinition = String("import numpy as np") - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (i int DEFAULT 1) COPY GRANTS RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'import numpy as np'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (i NUMBER(36, 2) DEFAULT 1) COPY GRANTS RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' EXTERNAL_ACCESS_INTEGRATIONS = ("ext_integration") SECRETS = ('variable1' = name1, 'variable2' = name2) STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'import numpy as np'`, id.FullyQualifiedName()) }) } @@ -258,7 +626,14 @@ func TestProcedures_CreateForScala(t *testing.T) { defaultOpts := func() *CreateForScalaProcedureOptions { return &CreateForScalaProcedureOptions{ - name: id, + name: id, + RuntimeVersion: "2.0", + Packages: []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + }, + Handler: "Echo.echoVarchar", } } @@ -267,113 +642,360 @@ func TestProcedures_CreateForScala(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) }) + t.Run("validation: [opts.RuntimeVersion] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.RuntimeVersion = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "RuntimeVersion")) + }) + + t.Run("validation: [opts.Packages] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Packages = nil + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "Packages")) + }) + + t.Run("validation: [opts.Handler] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Handler = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "Handler")) + }) + t.Run("validation: incorrect identifier", func(t *testing.T) { opts := defaultOpts() opts.name = emptySchemaObjectIdentifier assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: returns", func(t *testing.T) { + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{} assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForScalaProcedureOptions.Returns", "ResultDataType", "Table")) }) - t.Run("validation: function definition", func(t *testing.T) { + t.Run("validation: function definition", func(t *testing.T) { + opts := defaultOpts() + opts.TargetPath = String("@~/testfunc.jar") + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + assertOptsInvalidJoinedErrors(t, opts, NewError("TARGET_PATH must be nil when AS is nil")) + }) + + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "x", + ArgDataTypeOld: "VARCHAR", + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: "VARCHAR", + NotNull: Bool(true), + }, + } + opts.RuntimeVersion = "2.0" + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "@udf_libs/echohandler.jar", + }, + } + opts.Handler = "Echo.echoVarchar" + opts.TargetPath = String("@~/testfunc.jar") + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = String("return x") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (x VARCHAR DEFAULT 'test') COPY GRANTS RETURNS VARCHAR NOT NULL LANGUAGE SCALA RUNTIME_VERSION = '2.0' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return x'`, id.FullyQualifiedName()) + }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ProcedureArgument{ + { + ArgName: "x", + ArgDataType: dataTypeVarchar, + DefaultValue: String("'test'"), + }, + } + opts.CopyGrants = Bool(true) + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataType: dataTypeVarchar, + NotNull: Bool(true), + }, + } + opts.RuntimeVersion = "2.0" + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "@udf_libs/echohandler.jar", + }, + } + opts.Handler = "Echo.echoVarchar" + opts.TargetPath = String("@~/testfunc.jar") + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.Comment = String("test comment") + opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) + opts.ProcedureDefinition = String("return x") + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (x VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS VARCHAR(100) NOT NULL LANGUAGE SCALA RUNTIME_VERSION = '2.0' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return x'`, id.FullyQualifiedName()) + }) +} + +func TestProcedures_CreateForSQL(t *testing.T) { + id := randomSchemaObjectIdentifier() + + defaultOpts := func() *CreateForSQLProcedureOptions { + return &CreateForSQLProcedureOptions{ + name: id, + ProcedureDefinition: "3.141592654::FLOAT", + Returns: ProcedureSQLReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataType: dataTypeVarchar, + }, + }, + } + } + + t.Run("validation: nil options", func(t *testing.T) { + var opts *CreateForSQLProcedureOptions = nil + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("validation: [opts.ProcedureDefinition] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.ProcedureDefinition = "" + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForSQLProcedureOptions", "ProcedureDefinition")) + }) + + t.Run("validation: incorrect identifier", func(t *testing.T) { + opts := defaultOpts() + opts.name = emptySchemaObjectIdentifier + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("create with no arguments", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataType: dataTypeFloat, + }, + } + opts.ProcedureDefinition = "3.141592654::FLOAT" + assertOptsValidAndSQLEquals(t, opts, `CREATE PROCEDURE %s () RETURNS FLOAT LANGUAGE SQL AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat, ArgDataType: dataTypeFloat}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Arguments.ArgDataTypeOld opts.Arguments.ArgDataType] should be present - one valid, one invalid", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + {ArgName: "arg", ArgDataTypeOld: DataTypeFloat}, + {ArgName: "arg"}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + ResultDataType: &ProcedureReturnsResultDataType{}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.ResultDataType.ResultDataTypeOld opts.Returns.ResultDataType.ResultDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + ResultDataType: dataTypeFloat, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg"}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - two present", func(t *testing.T) { + opts := defaultOpts() + opts.Returns = ProcedureSQLReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat, ColumnDataType: dataTypeFloat}, + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + }) + + t.Run("validation: exactly one field from [opts.Returns.Table.Columns.ColumnDataTypeOld opts.Returns.Table.Columns.ColumnDataType] should be present - one valid, one invalid", func(t *testing.T) { opts := defaultOpts() - opts.TargetPath = String("@~/testfunc.jar") - opts.Packages = []ProcedurePackage{ - { - Package: "com.snowflake:snowpark:1.2.0", + opts.Returns = ProcedureSQLReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + {ColumnName: "arg", ColumnDataTypeOld: DataTypeFloat}, + {ColumnName: "arg"}, + }, }, } - assertOptsInvalidJoinedErrors(t, opts, NewError("TARGET_PATH must be nil when AS is nil")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) }) - t.Run("validation: options are missing", func(t *testing.T) { + t.Run("validation: exactly one field from [opts.Returns.ResultDataType opts.Returns.Table] should be present", func(t *testing.T) { opts := defaultOpts() - opts.Returns = ProcedureReturns{ - ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, - }, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "Handler")) - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForScalaProcedureOptions", "RuntimeVersion")) + opts.Returns = ProcedureSQLReturns{} + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateForSQLProcedureOptions.Returns", "ResultDataType", "Table")) }) - t.Run("all options", func(t *testing.T) { + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { opts := defaultOpts() opts.OrReplace = Bool(true) opts.Secure = Bool(true) opts.Arguments = []ProcedureArgument{ { - ArgName: "x", - ArgDataType: "VARCHAR", - DefaultValue: String("'test'"), + ArgName: "message", + ArgDataTypeOld: "VARCHAR", + DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) - opts.Returns = ProcedureReturns{ + opts.Returns = ProcedureSQLReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: "VARCHAR", - NotNull: Bool(true), - }, - } - opts.RuntimeVersion = "2.0" - opts.Packages = []ProcedurePackage{ - { - Package: "com.snowflake:snowpark:1.2.0", - }, - } - opts.Imports = []ProcedureImport{ - { - Import: "@udf_libs/echohandler.jar", + ResultDataTypeOld: "VARCHAR", }, + NotNull: Bool(true), } - opts.Handler = "Echo.echoVarchar" - opts.TargetPath = String("@~/testfunc.jar") opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) opts.Comment = String("test comment") opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) - opts.ProcedureDefinition = String("return x") - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (x VARCHAR DEFAULT 'test') COPY GRANTS RETURNS VARCHAR NOT NULL LANGUAGE SCALA RUNTIME_VERSION = '2.0' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('@udf_libs/echohandler.jar') HANDLER = 'Echo.echoVarchar' TARGET_PATH = '@~/testfunc.jar' STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS 'return x'`, id.FullyQualifiedName()) - }) -} - -func TestProcedures_CreateForSQL(t *testing.T) { - id := randomSchemaObjectIdentifier() - - defaultOpts := func() *CreateForSQLProcedureOptions { - return &CreateForSQLProcedureOptions{ - name: id, - } - } - - t.Run("validation: nil options", func(t *testing.T) { - var opts *CreateForSQLProcedureOptions = nil - assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) - }) - - t.Run("validation: incorrect identifier", func(t *testing.T) { - opts := defaultOpts() - opts.name = emptySchemaObjectIdentifier - assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) - }) - - t.Run("validation: options are missing", func(t *testing.T) { - opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateForSQLProcedureOptions", "ProcedureDefinition")) - }) - - t.Run("create with no arguments", func(t *testing.T) { - opts := defaultOpts() - opts.Returns = ProcedureSQLReturns{ - ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, - }, - } opts.ProcedureDefinition = "3.141592654::FLOAT" - assertOptsValidAndSQLEquals(t, opts, `CREATE PROCEDURE %s () RETURNS FLOAT LANGUAGE SQL AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (message VARCHAR DEFAULT 'test') COPY GRANTS RETURNS VARCHAR NOT NULL LANGUAGE SQL STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) t.Run("all options", func(t *testing.T) { @@ -383,14 +1005,14 @@ func TestProcedures_CreateForSQL(t *testing.T) { opts.Arguments = []ProcedureArgument{ { ArgName: "message", - ArgDataType: "VARCHAR", + ArgDataType: dataTypeVarchar, DefaultValue: String("'test'"), }, } opts.CopyGrants = Bool(true) opts.Returns = ProcedureSQLReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: "VARCHAR", + ResultDataType: dataTypeVarchar, }, NotNull: Bool(true), } @@ -398,7 +1020,7 @@ func TestProcedures_CreateForSQL(t *testing.T) { opts.Comment = String("test comment") opts.ExecuteAs = ExecuteAsPointer(ExecuteAsCaller) opts.ProcedureDefinition = "3.141592654::FLOAT" - assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (message VARCHAR DEFAULT 'test') COPY GRANTS RETURNS VARCHAR NOT NULL LANGUAGE SQL STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE PROCEDURE %s (message VARCHAR(100) DEFAULT 'test') COPY GRANTS RETURNS VARCHAR(100) NOT NULL LANGUAGE SQL STRICT COMMENT = 'test comment' EXECUTE AS CALLER AS '3.141592654::FLOAT'`, id.FullyQualifiedName()) }) } @@ -667,12 +1289,12 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { Columns: []ProcedureColumn{ { ColumnName: "name", - ColumnDataType: DataTypeVARCHAR, + ColumnDataType: dataTypeVarchar, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForJavaProcedureOptions.Returns", "ResultDataType", "Table")) @@ -682,7 +1304,7 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateAndCallForJavaProcedureOptions", "Handler")) @@ -706,16 +1328,65 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:latest') HANDLER = 'TestFunc.echoVarchar' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "id", + ArgDataTypeOld: DataTypeNumber, + }, + { + ArgName: "name", + ArgDataTypeOld: DataTypeVARCHAR, + }, + } + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + { + ColumnName: "country_code", + ColumnDataTypeOld: DataTypeVARCHAR, + }, + }, + }, + } + opts.RuntimeVersion = "1.8" + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "test_jar.jar", + }, + } + opts.Handler = "TestFunc.echoVarchar" + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.ProcedureDefinition = String("return id + name;") + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClause = &ProcedureWithClause{ + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + } + opts.ProcedureName = id + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1", "rnd"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER, name VARCHAR) RETURNS TABLE (country_code VARCHAR) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + }) + t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ { ArgName: "id", - ArgDataType: DataTypeNumber, + ArgDataType: dataTypeNumber, }, { ArgName: "name", - ArgDataType: DataTypeVARCHAR, + ArgDataType: dataTypeVarchar, }, } opts.Returns = ProcedureReturns{ @@ -723,7 +1394,7 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { Columns: []ProcedureColumn{ { ColumnName: "country_code", - ColumnDataType: DataTypeVARCHAR, + ColumnDataType: dataTypeVarchar, }, }, }, @@ -751,7 +1422,7 @@ func TestProcedures_CreateAndCallForJava(t *testing.T) { opts.ProcedureName = id opts.ScriptingVariable = String(":ret") opts.CallArguments = []string{"1", "rnd"} - assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER, name VARCHAR) RETURNS TABLE (country_code VARCHAR) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER(36, 2), name VARCHAR(100)) RETURNS TABLE (country_code VARCHAR(100)) LANGUAGE JAVA RUNTIME_VERSION = '1.8' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) } @@ -788,12 +1459,12 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { Columns: []ProcedureColumn{ { ColumnName: "name", - ColumnDataType: DataTypeVARCHAR, + ColumnDataType: dataTypeVarchar, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForScalaProcedureOptions.Returns", "ResultDataType", "Table")) @@ -803,7 +1474,7 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateAndCallForScalaProcedureOptions", "Handler")) @@ -827,16 +1498,67 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE SCALA RUNTIME_VERSION = '2.12' PACKAGES = ('com.snowflake:snowpark:1.2.0') HANDLER = 'TestFunc.echoVarchar' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "id", + ArgDataTypeOld: DataTypeNumber, + }, + { + ArgName: "name", + ArgDataTypeOld: DataTypeVARCHAR, + }, + } + opts.Returns = ProcedureReturns{ + Table: &ProcedureReturnsTable{ + Columns: []ProcedureColumn{ + { + ColumnName: "country_code", + ColumnDataTypeOld: DataTypeVARCHAR, + }, + }, + }, + } + opts.RuntimeVersion = "2.12" + opts.Packages = []ProcedurePackage{ + { + Package: "com.snowflake:snowpark:1.2.0", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "test_jar.jar", + }, + } + opts.Handler = "TestFunc.echoVarchar" + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.ProcedureDefinition = String("return id + name;") + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClauses = []ProcedureWithClause{ + { + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + }, + } + opts.ProcedureName = id + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1", "rnd"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER, name VARCHAR) RETURNS TABLE (country_code VARCHAR) LANGUAGE SCALA RUNTIME_VERSION = '2.12' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + }) + t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ { ArgName: "id", - ArgDataType: DataTypeNumber, + ArgDataType: dataTypeNumber, }, { ArgName: "name", - ArgDataType: DataTypeVARCHAR, + ArgDataType: dataTypeVarchar, }, } opts.Returns = ProcedureReturns{ @@ -844,7 +1566,7 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { Columns: []ProcedureColumn{ { ColumnName: "country_code", - ColumnDataType: DataTypeVARCHAR, + ColumnDataType: dataTypeVarchar, }, }, }, @@ -874,7 +1596,7 @@ func TestProcedures_CreateAndCallForScala(t *testing.T) { opts.ProcedureName = id opts.ScriptingVariable = String(":ret") opts.CallArguments = []string{"1", "rnd"} - assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER, name VARCHAR) RETURNS TABLE (country_code VARCHAR) LANGUAGE SCALA RUNTIME_VERSION = '2.12' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (id NUMBER(36, 2), name VARCHAR(100)) RETURNS TABLE (country_code VARCHAR(100)) LANGUAGE SCALA RUNTIME_VERSION = '2.12' PACKAGES = ('com.snowflake:snowpark:1.2.0') IMPORTS = ('test_jar.jar') HANDLER = 'TestFunc.echoVarchar' STRICT AS 'return id + name;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1, rnd) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) } @@ -911,12 +1633,12 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { Columns: []ProcedureColumn{ { ColumnName: "name", - ColumnDataType: DataTypeVARCHAR, + ColumnDataType: dataTypeVarchar, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForPythonProcedureOptions.Returns", "ResultDataType", "Table")) @@ -926,7 +1648,7 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { opts := defaultOpts() opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeVARCHAR, + ResultDataType: dataTypeVarchar, }, } assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateAndCallForPythonProcedureOptions", "Handler")) @@ -950,18 +1672,68 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('snowflake-snowpark-python') HANDLER = 'udf' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "i", + ArgDataTypeOld: "int", + DefaultValue: String("1"), + }, + } + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: "VARIANT", + Null: Bool(true), + }, + } + opts.RuntimeVersion = "3.8" + opts.Packages = []ProcedurePackage{ + { + Package: "numpy", + }, + { + Package: "pandas", + }, + } + opts.Imports = []ProcedureImport{ + { + Import: "numpy", + }, + { + Import: "pandas", + }, + } + opts.Handler = "udf" + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.ProcedureDefinition = String("import numpy as np") + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClauses = []ProcedureWithClause{ + { + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + }, + } + opts.ProcedureName = id + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (i int DEFAULT 1) RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' STRICT AS 'import numpy as np' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + }) + t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ { ArgName: "i", - ArgDataType: "int", + ArgDataType: dataTypeNumber, DefaultValue: String("1"), }, } opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: "VARIANT", + ResultDataType: dataTypeVariant, Null: Bool(true), }, } @@ -996,7 +1768,7 @@ func TestProcedures_CreateAndCallForPython(t *testing.T) { opts.ProcedureName = id opts.ScriptingVariable = String(":ret") opts.CallArguments = []string{"1"} - assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (i int DEFAULT 1) RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' STRICT AS 'import numpy as np' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (i NUMBER(36, 2) DEFAULT 1) RETURNS VARIANT NULL LANGUAGE PYTHON RUNTIME_VERSION = '3.8' PACKAGES = ('numpy', 'pandas') IMPORTS = ('numpy', 'pandas') HANDLER = 'udf' STRICT AS 'import numpy as np' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) } @@ -1028,10 +1800,38 @@ func TestProcedures_CreateAndCallForJavaScript(t *testing.T) { t.Run("no arguments", func(t *testing.T) { opts := defaultOpts() - opts.ResultDataType = "DOUBLE" + opts.ResultDataType = dataTypeFloat + opts.ProcedureDefinition = "return 1;" + opts.ProcedureName = id + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS FLOAT LANGUAGE JAVASCRIPT AS 'return 1;' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) + }) + + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "d", + ArgDataTypeOld: "DOUBLE", + DefaultValue: String("1.0"), + }, + } + opts.ResultDataTypeOld = "DOUBLE" + opts.NotNull = Bool(true) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) opts.ProcedureDefinition = "return 1;" + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClauses = []ProcedureWithClause{ + { + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + }, + } opts.ProcedureName = id - assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS DOUBLE LANGUAGE JAVASCRIPT AS 'return 1;' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (d DOUBLE DEFAULT 1.0) RETURNS DOUBLE NOT NULL LANGUAGE JAVASCRIPT STRICT AS 'return 1;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) t.Run("all options", func(t *testing.T) { @@ -1039,11 +1839,11 @@ func TestProcedures_CreateAndCallForJavaScript(t *testing.T) { opts.Arguments = []ProcedureArgument{ { ArgName: "d", - ArgDataType: "DOUBLE", + ArgDataType: dataTypeFloat, DefaultValue: String("1.0"), }, } - opts.ResultDataType = "DOUBLE" + opts.ResultDataType = dataTypeFloat opts.NotNull = Bool(true) opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) opts.ProcedureDefinition = "return 1;" @@ -1058,7 +1858,7 @@ func TestProcedures_CreateAndCallForJavaScript(t *testing.T) { opts.ProcedureName = id opts.ScriptingVariable = String(":ret") opts.CallArguments = []string{"1"} - assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (d DOUBLE DEFAULT 1.0) RETURNS DOUBLE NOT NULL LANGUAGE JAVASCRIPT STRICT AS 'return 1;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (d FLOAT DEFAULT 1.0) RETURNS FLOAT NOT NULL LANGUAGE JAVASCRIPT STRICT AS 'return 1;' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) } @@ -1095,12 +1895,12 @@ func TestProcedures_CreateAndCallForSQL(t *testing.T) { Columns: []ProcedureColumn{ { ColumnName: "name", - ColumnDataType: DataTypeVARCHAR, + ColumnDataType: dataTypeVarchar, }, }, }, ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns", "ResultDataType", "Table")) @@ -1122,18 +1922,49 @@ func TestProcedures_CreateAndCallForSQL(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE () RETURNS TABLE () LANGUAGE SQL AS '3.141592654::FLOAT' CALL %s ()`, id.FullyQualifiedName(), id.FullyQualifiedName()) }) + // TODO [SNOW-1348106]: remove with old procedure removal for V1 + t.Run("all options - old data types", func(t *testing.T) { + opts := defaultOpts() + opts.Arguments = []ProcedureArgument{ + { + ArgName: "message", + ArgDataTypeOld: "VARCHAR", + DefaultValue: String("'test'"), + }, + } + opts.Returns = ProcedureReturns{ + ResultDataType: &ProcedureReturnsResultDataType{ + ResultDataTypeOld: DataTypeFloat, + }, + } + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) + opts.ProcedureDefinition = "3.141592654::FLOAT" + cte := NewAccountObjectIdentifier("album_info_1976") + opts.WithClauses = []ProcedureWithClause{ + { + CteName: cte, + CteColumns: []string{"x", "y"}, + Statement: "(select m.album_ID, m.album_name, b.band_name from music_albums)", + }, + } + opts.ProcedureName = id + opts.ScriptingVariable = String(":ret") + opts.CallArguments = []string{"1"} + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (message VARCHAR DEFAULT 'test') RETURNS FLOAT LANGUAGE SQL STRICT AS '3.141592654::FLOAT' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + }) + t.Run("all options", func(t *testing.T) { opts := defaultOpts() opts.Arguments = []ProcedureArgument{ { ArgName: "message", - ArgDataType: "VARCHAR", + ArgDataType: dataTypeVarchar, DefaultValue: String("'test'"), }, } opts.Returns = ProcedureReturns{ ResultDataType: &ProcedureReturnsResultDataType{ - ResultDataType: DataTypeFloat, + ResultDataType: dataTypeFloat, }, } opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorStrict) @@ -1149,6 +1980,6 @@ func TestProcedures_CreateAndCallForSQL(t *testing.T) { opts.ProcedureName = id opts.ScriptingVariable = String(":ret") opts.CallArguments = []string{"1"} - assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (message VARCHAR DEFAULT 'test') RETURNS FLOAT LANGUAGE SQL STRICT AS '3.141592654::FLOAT' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, `WITH %s AS PROCEDURE (message VARCHAR(100) DEFAULT 'test') RETURNS FLOAT LANGUAGE SQL STRICT AS '3.141592654::FLOAT' , %s (x, y) AS (select m.album_ID, m.album_name, b.band_name from music_albums) CALL %s (1) INTO :ret`, id.FullyQualifiedName(), cte.FullyQualifiedName(), id.FullyQualifiedName()) }) } diff --git a/pkg/sdk/procedures_impl_gen.go b/pkg/sdk/procedures_impl_gen.go index 80cb096373..e63cf1f386 100644 --- a/pkg/sdk/procedures_impl_gen.go +++ b/pkg/sdk/procedures_impl_gen.go @@ -137,9 +137,10 @@ func (r *CreateForJavaProcedureRequest) toOpts() *CreateForJavaProcedureOptions opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -176,6 +177,7 @@ func (r *CreateForJavaScriptProcedureRequest) toOpts() *CreateForJavaScriptProce name: r.name, CopyGrants: r.CopyGrants, + ResultDataTypeOld: r.ResultDataTypeOld, ResultDataType: r.ResultDataType, NotNull: r.NotNull, NullInputBehavior: r.NullInputBehavior, @@ -221,9 +223,10 @@ func (r *CreateForPythonProcedureRequest) toOpts() *CreateForPythonProcedureOpti opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -280,9 +283,10 @@ func (r *CreateForScalaProcedureRequest) toOpts() *CreateForScalaProcedureOption opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -337,7 +341,8 @@ func (r *CreateForSQLProcedureRequest) toOpts() *CreateForSQLProcedureOptions { } if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, } } if r.Returns.Table != nil { @@ -407,7 +412,7 @@ func (r procedureRow) convert() *Procedure { if err != nil { log.Printf("[DEBUG] failed to parse procedure arguments, err = %s", err) } else { - e.Arguments = dataTypes + e.ArgumentsOld = dataTypes } if r.IsSecure.Valid { e.IsSecure = r.IsSecure.String == "Y" @@ -466,9 +471,10 @@ func (r *CreateAndCallForJavaProcedureRequest) toOpts() *CreateAndCallForJavaPro opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -529,9 +535,10 @@ func (r *CreateAndCallForScalaProcedureRequest) toOpts() *CreateAndCallForScalaP opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -576,6 +583,7 @@ func (r *CreateAndCallForJavaScriptProcedureRequest) toOpts() *CreateAndCallForJ opts := &CreateAndCallForJavaScriptProcedureOptions{ Name: r.Name, + ResultDataTypeOld: r.ResultDataTypeOld, ResultDataType: r.ResultDataType, NotNull: r.NotNull, NullInputBehavior: r.NullInputBehavior, @@ -630,9 +638,10 @@ func (r *CreateAndCallForPythonProcedureRequest) toOpts() *CreateAndCallForPytho opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { @@ -694,9 +703,10 @@ func (r *CreateAndCallForSQLProcedureRequest) toOpts() *CreateAndCallForSQLProce opts.Returns = ProcedureReturns{} if r.Returns.ResultDataType != nil { opts.Returns.ResultDataType = &ProcedureReturnsResultDataType{ - ResultDataType: r.Returns.ResultDataType.ResultDataType, - Null: r.Returns.ResultDataType.Null, - NotNull: r.Returns.ResultDataType.NotNull, + ResultDataTypeOld: r.Returns.ResultDataType.ResultDataTypeOld, + ResultDataType: r.Returns.ResultDataType.ResultDataType, + Null: r.Returns.ResultDataType.Null, + NotNull: r.Returns.ResultDataType.NotNull, } } if r.Returns.Table != nil { diff --git a/pkg/sdk/procedures_validations_gen.go b/pkg/sdk/procedures_validations_gen.go index c397a41500..5e7557176f 100644 --- a/pkg/sdk/procedures_validations_gen.go +++ b/pkg/sdk/procedures_validations_gen.go @@ -35,11 +35,35 @@ func (opts *CreateForJavaProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForJavaProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } + // added manually if opts.ProcedureDefinition == nil && opts.TargetPath != nil { errs = append(errs, NewError("TARGET_PATH must be nil when AS is nil")) } @@ -57,6 +81,17 @@ func (opts *CreateForJavaScriptProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if !exactlyOneValueSet(opts.ResultDataTypeOld, opts.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) + } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } return JoinErrors(errs...) } @@ -77,10 +112,33 @@ func (opts *CreateForPythonProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForPythonProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -102,11 +160,35 @@ func (opts *CreateForScalaProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForScalaProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } + // added manually if opts.ProcedureDefinition == nil && opts.TargetPath != nil { errs = append(errs, NewError("TARGET_PATH must be nil when AS is nil")) } @@ -124,10 +206,33 @@ func (opts *CreateForSQLProcedureOptions) validate() error { if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateForSQLProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -205,15 +310,39 @@ func (opts *CreateAndCallForJavaProcedureOptions) validate() error { errs = append(errs, errNotSet("CreateAndCallForJavaProcedureOptions", "Handler")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForJavaProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForJavaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateAndCallForJavaProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -233,15 +362,39 @@ func (opts *CreateAndCallForScalaProcedureOptions) validate() error { errs = append(errs, errNotSet("CreateAndCallForScalaProcedureOptions", "Handler")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForScalaProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForScalaProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateAndCallForScalaProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -254,15 +407,24 @@ func (opts *CreateAndCallForJavaScriptProcedureOptions) validate() error { if !valueSet(opts.ProcedureDefinition) { errs = append(errs, errNotSet("CreateAndCallForJavaScriptProcedureOptions", "ProcedureDefinition")) } - if !valueSet(opts.ResultDataType) { - errs = append(errs, errNotSet("CreateAndCallForJavaScriptProcedureOptions", "ResultDataType")) + if !exactlyOneValueSet(opts.ResultDataTypeOld, opts.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForJavaScriptProcedureOptions", "ResultDataTypeOld", "ResultDataType")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForJavaScriptProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForJavaScriptProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } return JoinErrors(errs...) } @@ -281,15 +443,39 @@ func (opts *CreateAndCallForPythonProcedureOptions) validate() error { errs = append(errs, errNotSet("CreateAndCallForPythonProcedureOptions", "Handler")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForPythonProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForPythonProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateAndCallForPythonProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } @@ -303,15 +489,39 @@ func (opts *CreateAndCallForSQLProcedureOptions) validate() error { errs = append(errs, errNotSet("CreateAndCallForSQLProcedureOptions", "ProcedureDefinition")) } if !ValidObjectIdentifier(opts.ProcedureName) { + // altered manually errs = append(errs, errInvalidIdentifier("CreateAndCallForSQLProcedureOptions", "ProcedureName")) } if !ValidObjectIdentifier(opts.Name) { errs = append(errs, ErrInvalidObjectIdentifier) } + if valueSet(opts.Arguments) { + // modified manually + for _, arg := range opts.Arguments { + if !exactlyOneValueSet(arg.ArgDataTypeOld, arg.ArgDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Arguments", "ArgDataTypeOld", "ArgDataType")) + } + } + } if valueSet(opts.Returns) { if !exactlyOneValueSet(opts.Returns.ResultDataType, opts.Returns.Table) { errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns", "ResultDataType", "Table")) } + if valueSet(opts.Returns.ResultDataType) { + if !exactlyOneValueSet(opts.Returns.ResultDataType.ResultDataTypeOld, opts.Returns.ResultDataType.ResultDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.ResultDataType", "ResultDataTypeOld", "ResultDataType")) + } + } + if valueSet(opts.Returns.Table) { + if valueSet(opts.Returns.Table.Columns) { + // modified manually + for _, col := range opts.Returns.Table.Columns { + if !exactlyOneValueSet(col.ColumnDataTypeOld, col.ColumnDataType) { + errs = append(errs, errExactlyOneOf("CreateAndCallForSQLProcedureOptions.Returns.Table.Columns", "ColumnDataTypeOld", "ColumnDataType")) + } + } + } + } } return JoinErrors(errs...) } diff --git a/pkg/sdk/random_test.go b/pkg/sdk/random_test.go index 552eb68c15..83880167df 100644 --- a/pkg/sdk/random_test.go +++ b/pkg/sdk/random_test.go @@ -2,6 +2,7 @@ package sdk import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) var ( @@ -14,6 +15,12 @@ var ( emptyDatabaseObjectIdentifier = NewDatabaseObjectIdentifier("", "") emptySchemaObjectIdentifier = NewSchemaObjectIdentifier("", "", "") emptySchemaObjectIdentifierWithArguments = NewSchemaObjectIdentifierWithArguments("", "", "") + + // TODO [SNOW-1843440]: create using constructors (when we add them)? + dataTypeNumber, _ = datatypes.ParseDataType("NUMBER(36, 2)") + dataTypeVarchar, _ = datatypes.ParseDataType("VARCHAR(100)") + dataTypeFloat, _ = datatypes.ParseDataType("FLOAT") + dataTypeVariant, _ = datatypes.ParseDataType("VARIANT") ) func randomSchemaObjectIdentifierWithArguments(argumentDataTypes ...DataType) SchemaObjectIdentifierWithArguments { diff --git a/pkg/sdk/sql_builder.go b/pkg/sdk/sql_builder.go index bae8ef485c..ba50ac65f6 100644 --- a/pkg/sdk/sql_builder.go +++ b/pkg/sdk/sql_builder.go @@ -6,6 +6,8 @@ import ( "strings" "time" "unsafe" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type modifierType string @@ -642,7 +644,15 @@ func (v sqlParameterClause) String() string { if v.value == nil { return s } + value := v.value + if dataType, ok := value.(datatypes.DataType); ok { + // We check like this and not by `dataType == nil` because for e.g. `var *datatypes.ArrayDataType` return false in a normal nil check + if reflect.ValueOf(dataType).IsZero() { + return s + } + value = dataType.ToSql() + } // key = "value" - s += v.qm.Modify(v.value) + s += v.qm.Modify(value) return s } diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index b5eac3de58..48289cdad1 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -1,9 +1,11 @@ package sdk import ( + "fmt" "reflect" "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -472,3 +474,126 @@ func TestBuilder_sql(t *testing.T) { assert.Equal(t, "EXAMPLE_STATIC EXAMPLE_KEYWORD = example", s) }) } + +func TestBuilder_DataType(t *testing.T) { + type dataTypeTestHelper struct { + DataType datatypes.DataType `ddl:"parameter,no_quotes,no_equals"` + } + + dataTypes := []struct { + dataType string + expectedSql string + }{ + {dataType: "ARRAY", expectedSql: "ARRAY"}, + {dataType: "array", expectedSql: "ARRAY"}, + {dataType: "BINARY", expectedSql: "BINARY(8388608)"}, + {dataType: "binary(120)", expectedSql: "BINARY(120)"}, + {dataType: "BOOLEAN", expectedSql: "BOOLEAN"}, + {dataType: "boolean", expectedSql: "BOOLEAN"}, + {dataType: "DATE", expectedSql: "DATE"}, + {dataType: "date", expectedSql: "DATE"}, + {dataType: "FLOAT", expectedSql: "FLOAT"}, + {dataType: "float4", expectedSql: "FLOAT4"}, + {dataType: "real", expectedSql: "REAL"}, + {dataType: "GEOGRAPHY", expectedSql: "GEOGRAPHY"}, + {dataType: "geography", expectedSql: "GEOGRAPHY"}, + {dataType: "GEOMETRY", expectedSql: "GEOMETRY"}, + {dataType: "geometry", expectedSql: "GEOMETRY"}, + {dataType: "NUMBER", expectedSql: "NUMBER(38, 0)"}, + {dataType: "NUMBER(36)", expectedSql: "NUMBER(36, 0)"}, + {dataType: "NUMBER(36, 2)", expectedSql: "NUMBER(36, 2)"}, + {dataType: "number(36, 2)", expectedSql: "NUMBER(36, 2)"}, + {dataType: "INT", expectedSql: "INT"}, + {dataType: "integer", expectedSql: "INTEGER"}, + {dataType: "OBJECT", expectedSql: "OBJECT"}, + {dataType: "object", expectedSql: "OBJECT"}, + {dataType: "VARCHAR(20)", expectedSql: "VARCHAR(20)"}, + {dataType: "VARCHAR", expectedSql: "VARCHAR(16777216)"}, + {dataType: "varchar", expectedSql: "VARCHAR(16777216)"}, + {dataType: "CHAR", expectedSql: "CHAR(1)"}, + {dataType: "char(34)", expectedSql: "CHAR(34)"}, + {dataType: "TIME", expectedSql: "TIME(9)"}, + {dataType: "time", expectedSql: "TIME(9)"}, + {dataType: "time(5)", expectedSql: "TIME(5)"}, + {dataType: "TIMESTAMP_LTZ", expectedSql: "TIMESTAMP_LTZ(9)"}, + {dataType: "timestamp_ltz", expectedSql: "TIMESTAMP_LTZ(9)"}, + {dataType: "timestampltz", expectedSql: "TIMESTAMPLTZ(9)"}, + {dataType: "timestampltz(5)", expectedSql: "TIMESTAMPLTZ(5)"}, + {dataType: "TIMESTAMP_NTZ", expectedSql: "TIMESTAMP_NTZ(9)"}, + {dataType: "timestamp_ntz", expectedSql: "TIMESTAMP_NTZ(9)"}, + {dataType: "timestamp_ntz(5)", expectedSql: "TIMESTAMP_NTZ(5)"}, + {dataType: "timestampntz", expectedSql: "TIMESTAMPNTZ(9)"}, + {dataType: "timestampntz(5)", expectedSql: "TIMESTAMPNTZ(5)"}, + {dataType: "TIMESTAMP_TZ", expectedSql: "TIMESTAMP_TZ(9)"}, + {dataType: "timestamp_tz", expectedSql: "TIMESTAMP_TZ(9)"}, + {dataType: "timestamp_tz(5)", expectedSql: "TIMESTAMP_TZ(5)"}, + {dataType: "timestamptz", expectedSql: "TIMESTAMPTZ(9)"}, + {dataType: "timestamptz(5)", expectedSql: "TIMESTAMPTZ(5)"}, + {dataType: "VARIANT", expectedSql: "VARIANT"}, + {dataType: "variant", expectedSql: "VARIANT"}, + {dataType: "VECTOR(INT, 20)", expectedSql: "VECTOR(INT, 20)"}, + {dataType: "VECTOR(FLOAT, 20)", expectedSql: "VECTOR(FLOAT, 20)"}, + {dataType: "VECTOR(int, 20)", expectedSql: "VECTOR(INT, 20)"}, + {dataType: "VECTOR(float, 20)", expectedSql: "VECTOR(FLOAT, 20)"}, + } + + nilTestCases := func() []datatypes.DataType { + var a *datatypes.ArrayDataType + var b *datatypes.BinaryDataType + var c *datatypes.BooleanDataType + var d *datatypes.DateDataType + var e *datatypes.FloatDataType + var f *datatypes.GeographyDataType + var g *datatypes.GeometryDataType + var h *datatypes.NumberDataType + var i *datatypes.ObjectDataType + var j *datatypes.TextDataType + var k *datatypes.TimeDataType + var l *datatypes.TimestampLtzDataType + var m *datatypes.TimestampNtzDataType + var n *datatypes.TimestampTzDataType + var o *datatypes.VariantDataType + var p *datatypes.VectorDataType + + return []datatypes.DataType{a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p} + }() + t.Run("test data type empty", func(t *testing.T) { + opts := dataTypeTestHelper{} + + s, err := structToSQL(opts) + + require.NoError(t, err) + assert.Equal(t, "", s) + }) + + for _, tc := range nilTestCases { + tc := tc + t.Run(fmt.Sprintf(`test for nil data type "%s"`, reflect.TypeOf(tc)), func(t *testing.T) { + opts := dataTypeTestHelper{ + DataType: tc, + } + + s, err := structToSQL(opts) + + require.NoError(t, err) + assert.Equal(t, "", s) + }) + } + + for _, tc := range dataTypes { + tc := tc + t.Run(fmt.Sprintf(`cheking building SQL for data type "%s, expecting "%s"`, tc.dataType, tc.expectedSql), func(t *testing.T) { + dataType, err := datatypes.ParseDataType(tc.dataType) + require.NoError(t, err) + + opts := dataTypeTestHelper{ + DataType: dataType, + } + + s, err := structToSQL(opts) + + require.NoError(t, err) + assert.Equal(t, tc.expectedSql, s) + }) + } +} diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index 44bb8b898a..5c19d66af4 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -44,9 +44,9 @@ func TestInt_CreateFunctions(t *testing.T) { } }` target := fmt.Sprintf("@~/tf-%d.jar", time.Now().Unix()) - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeVARCHAR).WithDefaultValue("'abc'") + argument := sdk.NewFunctionArgumentRequest("x", nil).WithDefaultValue("'abc'").WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForJavaFunctionRequest(id.SchemaObjectId(), *returns, "TestFunc.echoVarchar"). WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). @@ -77,9 +77,9 @@ func TestInt_CreateFunctions(t *testing.T) { return result; }` - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewFunctionArgumentRequest("d", sdk.DataTypeFloat) + argument := sdk.NewFunctionArgumentRequest("d", nil).WithArgDataTypeOld(sdk.DataTypeFloat) request := sdk.NewCreateForJavascriptFunctionRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). @@ -100,9 +100,9 @@ func TestInt_CreateFunctions(t *testing.T) { definition := ` def dump(i): print("Hello World!")` - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeVariant) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVariant) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewFunctionArgumentRequest("i", sdk.DataTypeNumber) + argument := sdk.NewFunctionArgumentRequest("i", nil).WithArgDataTypeOld(sdk.DataTypeNumber) request := sdk.NewCreateForPythonFunctionRequest(id.SchemaObjectId(), *returns, "3.8", "dump"). WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). @@ -127,8 +127,9 @@ def dump(i): } }` - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeVARCHAR) - request := sdk.NewCreateForScalaFunctionRequest(id.SchemaObjectId(), sdk.DataTypeVARCHAR, "Echo.echoVarchar"). + argument := sdk.NewFunctionArgumentRequest("x", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + request := sdk.NewCreateForScalaFunctionRequest(id.SchemaObjectId(), nil, "Echo.echoVarchar"). + WithResultDataTypeOld(sdk.DataTypeVARCHAR). WithOrReplace(true). WithArguments([]sdk.FunctionArgumentRequest{*argument}). WithRuntimeVersion("2.12"). @@ -148,9 +149,9 @@ def dump(i): definition := "3.141592654::FLOAT" - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) + argument := sdk.NewFunctionArgumentRequest("x", nil).WithArgDataTypeOld(sdk.DataTypeFloat) request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). WithArguments([]sdk.FunctionArgumentRequest{*argument}). WithOrReplace(true). @@ -170,7 +171,7 @@ def dump(i): definition := "3.141592654::FLOAT" - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). @@ -209,7 +210,7 @@ func TestInt_OtherFunctions(t *testing.T) { assert.Equal(t, 0, function.MaxNumArguments) } assert.NotEmpty(t, function.ArgumentsRaw) - assert.NotEmpty(t, function.Arguments) + assert.NotEmpty(t, function.ArgumentsOld) assert.NotEmpty(t, function.Description) assert.NotEmpty(t, function.CatalogName) assert.Equal(t, false, function.IsTableFunction) @@ -241,12 +242,12 @@ func TestInt_OtherFunctions(t *testing.T) { definition := "3.141592654::FLOAT" - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true) if withArguments { - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) + argument := sdk.NewFunctionArgumentRequest("x", nil).WithArgDataTypeOld(sdk.DataTypeFloat) request = request.WithArguments([]sdk.FunctionArgumentRequest{*argument}) } err := client.Functions.CreateForSQL(ctx, request) @@ -438,11 +439,11 @@ func TestInt_FunctionsShowByID(t *testing.T) { t.Helper() definition := "3.141592654::FLOAT" - dt := sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeFloat) + dt := sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeFloat) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) request := sdk.NewCreateForSQLFunctionRequest(id.SchemaObjectId(), *returns, definition).WithOrReplace(true) - argument := sdk.NewFunctionArgumentRequest("x", sdk.DataTypeFloat) + argument := sdk.NewFunctionArgumentRequest("x", nil).WithArgDataTypeOld(sdk.DataTypeFloat) request = request.WithArguments([]sdk.FunctionArgumentRequest{*argument}) err := client.Functions.CreateForSQL(ctx, request) require.NoError(t, err) @@ -497,41 +498,42 @@ func TestInt_FunctionsShowByID(t *testing.T) { require.Equal(t, *e, *es) }) - t.Run("function returns non detailed data types of arguments", func(t *testing.T) { + // TODO [SNOW-1348103]: remove with old function removal for V1 + t.Run("function returns non detailed data types of arguments - old data types", func(t *testing.T) { // This test proves that every detailed data types (e.g. VARCHAR(20) and NUMBER(10, 0)) are generalized // on Snowflake side (to e.g. VARCHAR and NUMBER) and that sdk.ToDataType mapping function maps detailed types // correctly to their generalized counterparts (same as in Snowflake). id := testClientHelper().Ids.RandomSchemaObjectIdentifier() args := []sdk.FunctionArgumentRequest{ - *sdk.NewFunctionArgumentRequest("A", "NUMBER(2, 0)"), - *sdk.NewFunctionArgumentRequest("B", "DECIMAL"), - *sdk.NewFunctionArgumentRequest("C", "INTEGER"), - *sdk.NewFunctionArgumentRequest("D", sdk.DataTypeFloat), - *sdk.NewFunctionArgumentRequest("E", "DOUBLE"), - *sdk.NewFunctionArgumentRequest("F", "VARCHAR(20)"), - *sdk.NewFunctionArgumentRequest("G", "CHAR"), - *sdk.NewFunctionArgumentRequest("H", sdk.DataTypeString), - *sdk.NewFunctionArgumentRequest("I", "TEXT"), - *sdk.NewFunctionArgumentRequest("J", sdk.DataTypeBinary), - *sdk.NewFunctionArgumentRequest("K", "VARBINARY"), - *sdk.NewFunctionArgumentRequest("L", sdk.DataTypeBoolean), - *sdk.NewFunctionArgumentRequest("M", sdk.DataTypeDate), - *sdk.NewFunctionArgumentRequest("N", "DATETIME"), - *sdk.NewFunctionArgumentRequest("O", sdk.DataTypeTime), - *sdk.NewFunctionArgumentRequest("R", sdk.DataTypeTimestampLTZ), - *sdk.NewFunctionArgumentRequest("S", sdk.DataTypeTimestampNTZ), - *sdk.NewFunctionArgumentRequest("T", sdk.DataTypeTimestampTZ), - *sdk.NewFunctionArgumentRequest("U", sdk.DataTypeVariant), - *sdk.NewFunctionArgumentRequest("V", sdk.DataTypeObject), - *sdk.NewFunctionArgumentRequest("W", sdk.DataTypeArray), - *sdk.NewFunctionArgumentRequest("X", sdk.DataTypeGeography), - *sdk.NewFunctionArgumentRequest("Y", sdk.DataTypeGeometry), - *sdk.NewFunctionArgumentRequest("Z", "VECTOR(INT, 16)"), + *sdk.NewFunctionArgumentRequest("A", nil).WithArgDataTypeOld("NUMBER(2, 0)"), + *sdk.NewFunctionArgumentRequest("B", nil).WithArgDataTypeOld("DECIMAL"), + *sdk.NewFunctionArgumentRequest("C", nil).WithArgDataTypeOld("INTEGER"), + *sdk.NewFunctionArgumentRequest("D", nil).WithArgDataTypeOld(sdk.DataTypeFloat), + *sdk.NewFunctionArgumentRequest("E", nil).WithArgDataTypeOld("DOUBLE"), + *sdk.NewFunctionArgumentRequest("F", nil).WithArgDataTypeOld("VARCHAR(20)"), + *sdk.NewFunctionArgumentRequest("G", nil).WithArgDataTypeOld("CHAR"), + *sdk.NewFunctionArgumentRequest("H", nil).WithArgDataTypeOld(sdk.DataTypeString), + *sdk.NewFunctionArgumentRequest("I", nil).WithArgDataTypeOld("TEXT"), + *sdk.NewFunctionArgumentRequest("J", nil).WithArgDataTypeOld(sdk.DataTypeBinary), + *sdk.NewFunctionArgumentRequest("K", nil).WithArgDataTypeOld("VARBINARY"), + *sdk.NewFunctionArgumentRequest("L", nil).WithArgDataTypeOld(sdk.DataTypeBoolean), + *sdk.NewFunctionArgumentRequest("M", nil).WithArgDataTypeOld(sdk.DataTypeDate), + *sdk.NewFunctionArgumentRequest("N", nil).WithArgDataTypeOld("DATETIME"), + *sdk.NewFunctionArgumentRequest("O", nil).WithArgDataTypeOld(sdk.DataTypeTime), + *sdk.NewFunctionArgumentRequest("R", nil).WithArgDataTypeOld(sdk.DataTypeTimestampLTZ), + *sdk.NewFunctionArgumentRequest("S", nil).WithArgDataTypeOld(sdk.DataTypeTimestampNTZ), + *sdk.NewFunctionArgumentRequest("T", nil).WithArgDataTypeOld(sdk.DataTypeTimestampTZ), + *sdk.NewFunctionArgumentRequest("U", nil).WithArgDataTypeOld(sdk.DataTypeVariant), + *sdk.NewFunctionArgumentRequest("V", nil).WithArgDataTypeOld(sdk.DataTypeObject), + *sdk.NewFunctionArgumentRequest("W", nil).WithArgDataTypeOld(sdk.DataTypeArray), + *sdk.NewFunctionArgumentRequest("X", nil).WithArgDataTypeOld(sdk.DataTypeGeography), + *sdk.NewFunctionArgumentRequest("Y", nil).WithArgDataTypeOld(sdk.DataTypeGeometry), + *sdk.NewFunctionArgumentRequest("Z", nil).WithArgDataTypeOld("VECTOR(INT, 16)"), } err := client.Functions.CreateForPython(ctx, sdk.NewCreateForPythonFunctionRequest( id, - *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.DataTypeVariant)), + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVariant)), "3.8", "add", ). @@ -542,7 +544,7 @@ func TestInt_FunctionsShowByID(t *testing.T) { dataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - dataType, err := datatypes.ParseDataType(string(arg.ArgDataType)) + dataType, err := datatypes.ParseDataType(string(arg.ArgDataTypeOld)) require.NoError(t, err) dataTypes[i] = sdk.LegacyDataTypeFrom(dataType) } @@ -550,6 +552,80 @@ func TestInt_FunctionsShowByID(t *testing.T) { function, err := client.Functions.ShowByID(ctx, idWithArguments) require.NoError(t, err) - require.Equal(t, dataTypes, function.Arguments) + require.Equal(t, dataTypes, function.ArgumentsOld) }) + + // This test shows behavior of detailed types (e.g. VARCHAR(20) and NUMBER(10, 0)) on Snowflake side for functions. + // For SHOW, data type is generalized both for argument and return type (to e.g. VARCHAR and NUMBER). + // FOR DESCRIBE, data type is generalized for argument and works weirdly for the return type: type is generalized to the canonical one, but we also get the attributes. + for _, tc := range []string{ + "NUMBER(36, 5)", + "NUMBER(36)", + "NUMBER", + "DECIMAL", + "INTEGER", + "FLOAT", + "DOUBLE", + "VARCHAR", + "VARCHAR(20)", + "CHAR", + "CHAR(10)", + "TEXT", + "BINARY", + "BINARY(1000)", + "VARBINARY", + "BOOLEAN", + "DATE", + "DATETIME", + "TIME", + "TIMESTAMP_LTZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_TZ", + "VARIANT", + "OBJECT", + "ARRAY", + "GEOGRAPHY", + "GEOMETRY", + "VECTOR(INT, 16)", + "VECTOR(FLOAT, 8)", + } { + tc := tc + t.Run(fmt.Sprintf("function returns non detailed data types of arguments for %s", tc), func(t *testing.T) { + id := testClientHelper().Ids.RandomSchemaObjectIdentifier() + argName := "A" + dataType, err := datatypes.ParseDataType(tc) + require.NoError(t, err) + args := []sdk.FunctionArgumentRequest{ + *sdk.NewFunctionArgumentRequest(argName, dataType), + } + + err = client.Functions.CreateForPython(ctx, sdk.NewCreateForPythonFunctionRequest( + id, + *sdk.NewFunctionReturnsRequest().WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(dataType)), + "3.8", + "add", + ). + WithArguments(args). + WithFunctionDefinition(fmt.Sprintf("def add(%[1]s): %[1]s", argName)), + ) + require.NoError(t, err) + + oldDataType := sdk.LegacyDataTypeFrom(dataType) + idWithArguments := sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), oldDataType) + + function, err := client.Functions.ShowByID(ctx, idWithArguments) + require.NoError(t, err) + assert.Equal(t, []sdk.DataType{oldDataType}, function.ArgumentsOld) + assert.Equal(t, fmt.Sprintf("%[1]s(%[2]s) RETURN %[2]s", id.Name(), oldDataType), function.ArgumentsRaw) + + details, err := client.Functions.Describe(ctx, idWithArguments) + require.NoError(t, err) + pairs := make(map[string]string) + for _, detail := range details { + pairs[detail.Property] = detail.Value + } + assert.Equal(t, fmt.Sprintf("(%s %s)", argName, oldDataType), pairs["signature"]) + assert.Equal(t, dataType.Canonical(), pairs["returns"]) + }) + } } diff --git a/pkg/sdk/testint/procedures_integration_test.go b/pkg/sdk/testint/procedures_integration_test.go index 309a3db9f9..4543791c4d 100644 --- a/pkg/sdk/testint/procedures_integration_test.go +++ b/pkg/sdk/testint/procedures_integration_test.go @@ -7,6 +7,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -46,9 +47,9 @@ func TestInt_CreateProcedures(t *testing.T) { } }` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewProcedureArgumentRequest("input", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("input", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, "11", packages, "FileReader.execute"). WithOrReplace(true). @@ -77,13 +78,13 @@ func TestInt_CreateProcedures(t *testing.T) { return filteredRows; } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("table_name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("table_name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, "11", packages, "Filter.filterByRole"). WithOrReplace(true). @@ -114,8 +115,9 @@ func TestInt_CreateProcedures(t *testing.T) { catch (err) { return "Failed: " + err; // Return a success/error indicator. }` - argument := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", sdk.DataTypeFloat) - request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeString, definition). + argument := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", nil).WithArgDataTypeOld(sdk.DataTypeFloat) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, definition). + WithResultDataTypeOld(sdk.DataTypeString). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) @@ -134,7 +136,7 @@ func TestInt_CreateProcedures(t *testing.T) { id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name) definition := `return 3.1415926;` - request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeFloat, definition).WithNotNull(true).WithOrReplace(true) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, definition).WithResultDataTypeOld(sdk.DataTypeFloat).WithNotNull(true).WithOrReplace(true) err := client.Procedures.CreateForJavaScript(ctx, request) require.NoError(t, err) t.Cleanup(cleanupProcedureHandle(id)) @@ -160,9 +162,9 @@ func TestInt_CreateProcedures(t *testing.T) { return new String(input.readAllBytes(), StandardCharsets.UTF_8) } }` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewProcedureArgumentRequest("input", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("input", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, "2.12", packages, "FileReader.execute"). WithOrReplace(true). @@ -192,13 +194,13 @@ func TestInt_CreateProcedures(t *testing.T) { return filteredRows } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("table_name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("table_name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, "2.12", packages, "Filter.filterByRole"). WithOrReplace(true). @@ -225,9 +227,9 @@ def joblib_multiprocessing(session, i): result = joblib.Parallel(n_jobs=-1)(joblib.delayed(sqrt)(i ** 2) for i in range(10)) return str(result)` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeString) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeString) returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewProcedureArgumentRequest("i", sdk.DataTypeInt) + argument := sdk.NewProcedureArgumentRequest("i", nil).WithArgDataTypeOld(sdk.DataTypeInt) packages := []sdk.ProcedurePackageRequest{ *sdk.NewProcedurePackageRequest("snowflake-snowpark-python"), *sdk.NewProcedurePackageRequest("joblib"), @@ -255,13 +257,13 @@ from snowflake.snowpark.functions import col def filter_by_role(session, table_name, role): df = session.table(table_name) return df.filter(col("role") == role)` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("table_name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("table_name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} request := sdk.NewCreateForPythonProcedureRequest(id.SchemaObjectId(), *returns, "3.8", packages, "filter_by_role"). WithOrReplace(true). @@ -286,9 +288,9 @@ def filter_by_role(session, table_name, role): RETURN message; END;` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). // Suddenly this is erroring out, when it used to not have an problem. Must be an error with the Snowflake API. @@ -318,11 +320,11 @@ def filter_by_role(session, table_name, role): BEGIN RETURN TABLE(res); END;` - column1 := sdk.NewProcedureColumnRequest("id", "INTEGER") - column2 := sdk.NewProcedureColumnRequest("price", "NUMBER(12,2)") + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld("INTEGER") + column2 := sdk.NewProcedureColumnRequest("price", nil).WithColumnDataTypeOld("NUMBER(12,2)") returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2}) returns := sdk.NewProcedureSQLReturnsRequest().WithTable(*returnsTable) - argument := sdk.NewProcedureArgumentRequest("id", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("id", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). // SNOW-1051627 todo: uncomment once null input behavior working again @@ -355,7 +357,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { assert.Equal(t, false, procedure.IsAnsi) assert.Equal(t, 1, procedure.MinNumArguments) assert.Equal(t, 1, procedure.MaxNumArguments) - assert.NotEmpty(t, procedure.Arguments) + assert.NotEmpty(t, procedure.ArgumentsOld) assert.NotEmpty(t, procedure.ArgumentsRaw) assert.NotEmpty(t, procedure.Description) assert.NotEmpty(t, procedure.CatalogName) @@ -382,9 +384,9 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { RETURN message; END;` id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithSecure(true). WithOrReplace(true). @@ -498,7 +500,7 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { require.Equal(t, 0, len(procedures)) }) - t.Run("describe function for SQL", func(t *testing.T) { + t.Run("describe procedure for SQL", func(t *testing.T) { f := createProcedureForSQLHandle(t, true) id := f.ID() @@ -520,9 +522,9 @@ func TestInt_OtherProcedureFunctions(t *testing.T) { RETURN message; END;` id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). @@ -575,9 +577,9 @@ func TestInt_CallProcedure(t *testing.T) { RETURN message; END;` id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithSecure(true). WithOrReplace(true). @@ -619,13 +621,13 @@ func TestInt_CallProcedure(t *testing.T) { return filteredRows; } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForJavaProcedureRequest(id.SchemaObjectId(), *returns, "11", packages, "Filter.filterByRole"). WithOrReplace(true). @@ -658,8 +660,8 @@ func TestInt_CallProcedure(t *testing.T) { }` returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} request := sdk.NewCreateForScalaProcedureRequest(id.SchemaObjectId(), *returns, "2.12", packages, "Filter.filterByRole"). WithOrReplace(true). @@ -690,8 +692,9 @@ func TestInt_CallProcedure(t *testing.T) { catch (err) { return "Failed: " + err; // Return a success/error indicator. }` - arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", sdk.DataTypeFloat) - request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeString, definition). + arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", nil).WithArgDataTypeOld(sdk.DataTypeFloat) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, definition). + WithResultDataTypeOld(sdk.DataTypeString). WithOrReplace(true). WithArguments([]sdk.ProcedureArgumentRequest{*arg}). WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). @@ -710,7 +713,7 @@ func TestInt_CallProcedure(t *testing.T) { id := sdk.NewSchemaObjectIdentifierWithArguments(databaseId.Name(), schemaId.Name(), name) definition := `return 3.1415926;` - request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.DataTypeFloat, definition).WithNotNull(true).WithOrReplace(true) + request := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), nil, definition).WithResultDataTypeOld(sdk.DataTypeFloat).WithNotNull(true).WithOrReplace(true) err := client.Procedures.CreateForJavaScript(ctx, request) require.NoError(t, err) t.Cleanup(cleanupProcedureHandle(id)) @@ -730,8 +733,8 @@ def filter_by_role(session, name, role): return df.filter(col("role") == role)` returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} request := sdk.NewCreateForPythonProcedureRequest(id.SchemaObjectId(), *returns, "3.8", packages, "filter_by_role"). WithOrReplace(true). @@ -783,13 +786,13 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { return filteredRows; } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} request := sdk.NewCreateAndCallForJavaProcedureRequest(name, *returns, "11", packages, "Filter.filterByRole", name). @@ -816,13 +819,13 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { return filteredRows } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} request := sdk.NewCreateAndCallForScalaProcedureRequest(name, *returns, "2.12", packages, "Filter.filterByRole", name). @@ -849,8 +852,9 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { catch (err) { return "Failed: " + err; // Return a success/error indicator. }` - arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", sdk.DataTypeFloat) - request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, sdk.DataTypeString, definition, name). + arg := sdk.NewProcedureArgumentRequest("FLOAT_PARAM1", nil).WithArgDataTypeOld(sdk.DataTypeFloat) + request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, nil, definition, name). + WithResultDataTypeOld(sdk.DataTypeString). WithArguments([]sdk.ProcedureArgumentRequest{*arg}). WithNullInputBehavior(*sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorStrict)). WithCallArguments([]string{"5.14::FLOAT"}) @@ -864,7 +868,7 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { name := sdk.NewAccountObjectIdentifier("sp_pi") definition := `return 3.1415926;` - request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, sdk.DataTypeFloat, definition, name).WithNotNull(true) + request := sdk.NewCreateAndCallForJavaScriptProcedureRequest(name, nil, definition, name).WithResultDataTypeOld(sdk.DataTypeFloat).WithNotNull(true) err := client.Procedures.CreateAndCallForJavaScript(ctx, request) require.NoError(t, err) }) @@ -876,9 +880,9 @@ func TestInt_CreateAndCallProcedures(t *testing.T) { END;` name := testClientHelper().Ids.RandomAccountObjectIdentifier() - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureReturnsRequest().WithResultDataType(*dt) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateAndCallForSQLProcedureRequest(name, *returns, definition, name). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). WithCallArguments([]string{"message => 'hi'"}) @@ -897,8 +901,8 @@ def filter_by_role(session, name, role): return df.filter(col("role") == role)` returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} request := sdk.NewCreateAndCallForPythonProcedureRequest(name, *returns, "3.8", packages, "filter_by_role", name). @@ -922,13 +926,13 @@ def filter_by_role(session, name, role): return filteredRows; } }` - column1 := sdk.NewProcedureColumnRequest("id", sdk.DataTypeNumber) - column2 := sdk.NewProcedureColumnRequest("name", sdk.DataTypeVARCHAR) - column3 := sdk.NewProcedureColumnRequest("role", sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld(sdk.DataTypeNumber) + column2 := sdk.NewProcedureColumnRequest("name", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) + column3 := sdk.NewProcedureColumnRequest("role", nil).WithColumnDataTypeOld(sdk.DataTypeVARCHAR) returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) returns := sdk.NewProcedureReturnsRequest().WithTable(*returnsTable) - arg1 := sdk.NewProcedureArgumentRequest("name", sdk.DataTypeVARCHAR) - arg2 := sdk.NewProcedureArgumentRequest("role", sdk.DataTypeVARCHAR) + arg1 := sdk.NewProcedureArgumentRequest("name", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) + arg2 := sdk.NewProcedureArgumentRequest("role", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("com.snowflake:snowpark:latest")} ca := []string{fmt.Sprintf(`'%s'`, tid.FullyQualifiedName()), "'dev'"} @@ -967,9 +971,9 @@ func TestInt_ProceduresShowByID(t *testing.T) { BEGIN RETURN message; END;` - dt := sdk.NewProcedureReturnsResultDataTypeRequest(sdk.DataTypeVARCHAR) + dt := sdk.NewProcedureReturnsResultDataTypeRequest(nil).WithResultDataTypeOld(sdk.DataTypeVARCHAR) returns := sdk.NewProcedureSQLReturnsRequest().WithResultDataType(*dt).WithNotNull(true) - argument := sdk.NewProcedureArgumentRequest("message", sdk.DataTypeVARCHAR) + argument := sdk.NewProcedureArgumentRequest("message", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequest(id.SchemaObjectId(), *returns, definition). WithArguments([]sdk.ProcedureArgumentRequest{*argument}). WithExecuteAs(*sdk.ExecuteAsPointer(sdk.ExecuteAsCaller)) @@ -1019,4 +1023,81 @@ func TestInt_ProceduresShowByID(t *testing.T) { require.NoError(t, err) require.Equal(t, *e, *es) }) + + // This test shows behavior of detailed types (e.g. VARCHAR(20) and NUMBER(10, 0)) on Snowflake side for procedures. + // For SHOW, data type is generalized both for argument and return type (to e.g. VARCHAR and NUMBER). + // FOR DESCRIBE, data type is generalized for argument and works weirdly for the return type: type is generalized to the canonical one, but we also get the attributes. + for _, tc := range []string{ + "NUMBER(36, 5)", + "NUMBER(36)", + "NUMBER", + "DECIMAL", + "INTEGER", + "FLOAT", + "DOUBLE", + "VARCHAR", + "VARCHAR(20)", + "CHAR", + "CHAR(10)", + "TEXT", + "BINARY", + "BINARY(1000)", + "VARBINARY", + "BOOLEAN", + "DATE", + "DATETIME", + "TIME", + "TIMESTAMP_LTZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_TZ", + "VARIANT", + "OBJECT", + "ARRAY", + "GEOGRAPHY", + "GEOMETRY", + "VECTOR(INT, 16)", + "VECTOR(FLOAT, 8)", + } { + tc := tc + t.Run(fmt.Sprintf("procedure returns non detailed data types of arguments for %s", tc), func(t *testing.T) { + procName := "add" + argName := "A" + dataType, err := datatypes.ParseDataType(tc) + require.NoError(t, err) + args := []sdk.ProcedureArgumentRequest{ + *sdk.NewProcedureArgumentRequest(argName, dataType), + } + oldDataType := sdk.LegacyDataTypeFrom(dataType) + idWithArguments := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(oldDataType) + + packages := []sdk.ProcedurePackageRequest{*sdk.NewProcedurePackageRequest("snowflake-snowpark-python")} + definition := fmt.Sprintf("def add(%[1]s): %[1]s", argName) + + err = client.Procedures.CreateForPython(ctx, sdk.NewCreateForPythonProcedureRequest( + idWithArguments.SchemaObjectId(), + *sdk.NewProcedureReturnsRequest().WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(dataType)), + "3.8", + packages, + procName, + ). + WithArguments(args). + WithProcedureDefinition(definition), + ) + require.NoError(t, err) + + procedure, err := client.Procedures.ShowByID(ctx, idWithArguments) + require.NoError(t, err) + assert.Equal(t, []sdk.DataType{oldDataType}, procedure.ArgumentsOld) + assert.Equal(t, fmt.Sprintf("%[1]s(%[2]s) RETURN %[2]s", idWithArguments.Name(), oldDataType), procedure.ArgumentsRaw) + + details, err := client.Procedures.Describe(ctx, idWithArguments) + require.NoError(t, err) + pairs := make(map[string]string) + for _, detail := range details { + pairs[detail.Property] = detail.Value + } + assert.Equal(t, fmt.Sprintf("(%s %s)", argName, oldDataType), pairs["signature"]) + assert.Equal(t, dataType.Canonical(), pairs["returns"]) + }) + } }