diff --git a/pkg/modules/generators/accessories/mysql/alicloud_rds.go b/pkg/modules/generators/accessories/mysql/alicloud_rds.go index 97010a19..6ef17c5f 100644 --- a/pkg/modules/generators/accessories/mysql/alicloud_rds.go +++ b/pkg/modules/generators/accessories/mysql/alicloud_rds.go @@ -1,6 +1,8 @@ package mysql import ( + "fmt" + "os" "strings" v1 "k8s.io/api/core/v1" @@ -14,6 +16,7 @@ import ( const ( defaultAlicloudProviderURL = "registry.terraform.io/aliyun/alicloud/1.209.1" + alicloudRegionEnv = "ALICLOUD_REGION" alicloudDBInstance = "alicloud_db_instance" alicloudDBConnection = "alicloud_db_connection" alicloudRDSAccount = "alicloud_rds_account" @@ -58,9 +61,12 @@ func (g *mysqlGenerator) generateAlicloudResources(db *mysql.MySQL, spec *apiv1. } // Get the alicloud provider region, and the region of the alicloud provider must be set. - alicloudProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider]) - if err != nil { - return nil, err + var alicloudProviderRegion string + if alicloudProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider]); alicloudProviderRegion == "" { + alicloudProviderRegion = os.Getenv(alicloudRegionEnv) + } + if alicloudProviderRegion == "" { + return nil, fmt.Errorf("alicloud provider region should not be empty") } // Build alicloud_db_instance. diff --git a/pkg/modules/generators/accessories/mysql/aws_rds.go b/pkg/modules/generators/accessories/mysql/aws_rds.go index 1ff2322d..e563fed9 100644 --- a/pkg/modules/generators/accessories/mysql/aws_rds.go +++ b/pkg/modules/generators/accessories/mysql/aws_rds.go @@ -2,6 +2,7 @@ package mysql import ( "fmt" + "os" v1 "k8s.io/api/core/v1" @@ -14,6 +15,7 @@ import ( const ( defaultAWSProviderURL = "registry.terraform.io/hashicorp/aws/5.0.1" + awsRegionEnv = "AWS_REGION" awsSecurityGroup = "aws_security_group" awsDBInstance = "aws_db_instance" ) @@ -62,9 +64,12 @@ func (g *mysqlGenerator) generateAWSResources(db *mysql.MySQL, spec *apiv1.Inten } // Get the aws provider region, and the region of the aws provider must be set. - awsProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider]) - if err != nil { - return nil, err + var awsProviderRegion string + if awsProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider]); awsProviderRegion == "" { + awsProviderRegion = os.Getenv(awsRegionEnv) + } + if awsProviderRegion == "" { + return nil, fmt.Errorf("aws provider region should not be empty") } // Build random_password for aws_db_instance. diff --git a/pkg/modules/generators/accessories/mysql/mysql_generator.go b/pkg/modules/generators/accessories/mysql/mysql_generator.go index f3760e41..9300913c 100644 --- a/pkg/modules/generators/accessories/mysql/mysql_generator.go +++ b/pkg/modules/generators/accessories/mysql/mysql_generator.go @@ -119,7 +119,7 @@ func (g *mysqlGenerator) Generate(spec *apiv1.Intent) error { return err } - switch providerType { + switch strings.ToLower(providerType) { case "aws": secret, err = g.generateAWSResources(db, spec) case "alicloud": diff --git a/pkg/modules/generators/accessories/postgres/alicloud_rds.go b/pkg/modules/generators/accessories/postgres/alicloud_rds.go index 7d0c09ec..07763aa7 100644 --- a/pkg/modules/generators/accessories/postgres/alicloud_rds.go +++ b/pkg/modules/generators/accessories/postgres/alicloud_rds.go @@ -1,6 +1,8 @@ package postgres import ( + "fmt" + "os" "strings" v1 "k8s.io/api/core/v1" @@ -12,6 +14,7 @@ import ( const ( defaultAlicloudProviderURL = "registry.terraform.io/aliyun/alicloud/1.209.1" + alicloudRegionEnv = "ALICLOUD_REGION" alicloudDBInstance = "alicloud_db_instance" alicloudDBConnection = "alicloud_db_connection" alicloudRDSAccount = "alicloud_rds_account" @@ -56,9 +59,12 @@ func (g *postgresGenerator) generateAlicloudResources(db *postgres.PostgreSQL, s } // Get the alicloud provider region, and the region of the alicloud provider must be set. - alicloudProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider]) - if err != nil { - return nil, err + var alicloudProviderRegion string + if alicloudProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider]); alicloudProviderRegion == "" { + alicloudProviderRegion = os.Getenv(alicloudRegionEnv) + } + if alicloudProviderRegion == "" { + return nil, fmt.Errorf("alicloud provider region should not be empty") } // Build alicloud_db_instance. diff --git a/pkg/modules/generators/accessories/postgres/aws_rds.go b/pkg/modules/generators/accessories/postgres/aws_rds.go index b6d3398a..150c5a9e 100644 --- a/pkg/modules/generators/accessories/postgres/aws_rds.go +++ b/pkg/modules/generators/accessories/postgres/aws_rds.go @@ -2,6 +2,7 @@ package postgres import ( "fmt" + "os" v1 "k8s.io/api/core/v1" apiv1 "kusionstack.io/kusion/pkg/apis/core/v1" @@ -12,6 +13,7 @@ import ( const ( defaultAWSProviderURL = "registry.terraform.io/hashicorp/aws/5.0.1" + awsRegionEnv = "AWS_REGION" awsSecurityGroup = "aws_security_group" awsDBInstance = "aws_db_instance" ) @@ -60,9 +62,12 @@ func (g *postgresGenerator) generateAWSResources(db *postgres.PostgreSQL, spec * } // Get the aws provider region, and the region of the aws provider must be set. - awsProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider]) - if err != nil { - return nil, err + var awsProviderRegion string + if awsProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider]); awsProviderRegion == "" { + awsProviderRegion = os.Getenv(awsRegionEnv) + } + if awsProviderRegion == "" { + return nil, fmt.Errorf("aws provider region should not be empty") } // Build random_password for aws_db_instance. diff --git a/pkg/modules/generators/accessories/postgres/postgres_generator.go b/pkg/modules/generators/accessories/postgres/postgres_generator.go index 9f81d139..af387835 100644 --- a/pkg/modules/generators/accessories/postgres/postgres_generator.go +++ b/pkg/modules/generators/accessories/postgres/postgres_generator.go @@ -119,7 +119,7 @@ func (g *postgresGenerator) Generate(spec *apiv1.Intent) error { return err } - switch providerType { + switch strings.ToLower(providerType) { case "aws": secret, err = g.generateAWSResources(db, spec) case "alicloud": diff --git a/pkg/modules/inputs/provider.go b/pkg/modules/inputs/provider.go index 884a061c..f57c7772 100644 --- a/pkg/modules/inputs/provider.go +++ b/pkg/modules/inputs/provider.go @@ -11,7 +11,6 @@ import ( const ( errInvalidProviderSource = "invalid provider source: %s" errEmptyProviderVersion = "empty provider version" - errEmptyProviderRegion = "empty provider region for source: %s" ) const ( @@ -75,11 +74,11 @@ func GetProviderURL(providerConfig *apiv1.ProviderConfig) (string, error) { } // GetProviderRegion returns the region of the terraform provider. -func GetProviderRegion(providerConfig *apiv1.ProviderConfig) (string, error) { +func GetProviderRegion(providerConfig *apiv1.ProviderConfig) string { region, ok := providerConfig.GenericConfig["region"] if !ok { - return "", fmt.Errorf(errEmptyProviderRegion, providerConfig.Source) + return "" } - return region.(string), nil + return region.(string) } diff --git a/pkg/modules/inputs/provider_test.go b/pkg/modules/inputs/provider_test.go index e1e67d7f..a7066bc7 100644 --- a/pkg/modules/inputs/provider_test.go +++ b/pkg/modules/inputs/provider_test.go @@ -109,10 +109,9 @@ func TestGetProviderURL(t *testing.T) { func TestGetProviderRegion(t *testing.T) { tests := []struct { - name string - data *apiv1.ProviderConfig - expected string - expectedErr error + name string + data *apiv1.ProviderConfig + expected string }{ { name: "Valid Provider Config", @@ -123,8 +122,7 @@ func TestGetProviderRegion(t *testing.T) { "region": "us-east-1", }, }, - expected: "us-east-1", - expectedErr: nil, + expected: "us-east-1", }, { name: "Empty Provider Region", @@ -132,18 +130,12 @@ func TestGetProviderRegion(t *testing.T) { Source: "hashicorp/aws", Version: "5.0.1", }, - expected: "", - expectedErr: fmt.Errorf(errEmptyProviderRegion, "hashicorp/aws"), + expected: "", }, } for _, test := range tests { - actual, actualErr := GetProviderRegion(test.data) - if test.expectedErr == nil { - assert.Equal(t, test.expected, actual) - assert.NoError(t, actualErr) - } else { - assert.ErrorContains(t, actualErr, test.expectedErr.Error()) - } + actual := GetProviderRegion(test.data) + assert.Equal(t, test.expected, actual) } }