Skip to content

Commit

Permalink
fix: fix some issues about database generator (#742)
Browse files Browse the repository at this point in the history
  • Loading branch information
liu-hm19 committed Jan 15, 2024
1 parent 0dfe9c5 commit f7c8fed
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 33 deletions.
12 changes: 9 additions & 3 deletions pkg/modules/generators/accessories/mysql/alicloud_rds.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package mysql

import (
"fmt"
"os"
"strings"

v1 "k8s.io/api/core/v1"
Expand All @@ -14,6 +16,7 @@ import (

const (
defaultAlicloudProviderURL = "registry.terraform.io/aliyun/alicloud/1.209.1"
alicloudRegionEnv = "ALICLOUD_REGION"
alicloudDBInstance = "alicloud_db_instance"
alicloudDBConnection = "alicloud_db_connection"
alicloudRDSAccount = "alicloud_rds_account"
Expand Down Expand Up @@ -58,9 +61,12 @@ func (g *mysqlGenerator) generateAlicloudResources(db *mysql.MySQL, spec *apiv1.
}

// Get the alicloud provider region, and the region of the alicloud provider must be set.
alicloudProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider])
if err != nil {
return nil, err
var alicloudProviderRegion string
if alicloudProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider]); alicloudProviderRegion == "" {
alicloudProviderRegion = os.Getenv(alicloudRegionEnv)
}
if alicloudProviderRegion == "" {
return nil, fmt.Errorf("alicloud provider region should not be empty")
}

// Build alicloud_db_instance.
Expand Down
11 changes: 8 additions & 3 deletions pkg/modules/generators/accessories/mysql/aws_rds.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mysql

import (
"fmt"
"os"

v1 "k8s.io/api/core/v1"

Expand All @@ -14,6 +15,7 @@ import (

const (
defaultAWSProviderURL = "registry.terraform.io/hashicorp/aws/5.0.1"
awsRegionEnv = "AWS_REGION"
awsSecurityGroup = "aws_security_group"
awsDBInstance = "aws_db_instance"
)
Expand Down Expand Up @@ -62,9 +64,12 @@ func (g *mysqlGenerator) generateAWSResources(db *mysql.MySQL, spec *apiv1.Inten
}

// Get the aws provider region, and the region of the aws provider must be set.
awsProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider])
if err != nil {
return nil, err
var awsProviderRegion string
if awsProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider]); awsProviderRegion == "" {
awsProviderRegion = os.Getenv(awsRegionEnv)
}
if awsProviderRegion == "" {
return nil, fmt.Errorf("aws provider region should not be empty")
}

// Build random_password for aws_db_instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (g *mysqlGenerator) Generate(spec *apiv1.Intent) error {
return err
}

switch providerType {
switch strings.ToLower(providerType) {
case "aws":
secret, err = g.generateAWSResources(db, spec)
case "alicloud":
Expand Down
12 changes: 9 additions & 3 deletions pkg/modules/generators/accessories/postgres/alicloud_rds.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package postgres

import (
"fmt"
"os"
"strings"

v1 "k8s.io/api/core/v1"
Expand All @@ -12,6 +14,7 @@ import (

const (
defaultAlicloudProviderURL = "registry.terraform.io/aliyun/alicloud/1.209.1"
alicloudRegionEnv = "ALICLOUD_REGION"
alicloudDBInstance = "alicloud_db_instance"
alicloudDBConnection = "alicloud_db_connection"
alicloudRDSAccount = "alicloud_rds_account"
Expand Down Expand Up @@ -56,9 +59,12 @@ func (g *postgresGenerator) generateAlicloudResources(db *postgres.PostgreSQL, s
}

// Get the alicloud provider region, and the region of the alicloud provider must be set.
alicloudProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider])
if err != nil {
return nil, err
var alicloudProviderRegion string
if alicloudProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider]); alicloudProviderRegion == "" {
alicloudProviderRegion = os.Getenv(alicloudRegionEnv)
}
if alicloudProviderRegion == "" {
return nil, fmt.Errorf("alicloud provider region should not be empty")
}

// Build alicloud_db_instance.
Expand Down
11 changes: 8 additions & 3 deletions pkg/modules/generators/accessories/postgres/aws_rds.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package postgres

import (
"fmt"
"os"

v1 "k8s.io/api/core/v1"
apiv1 "kusionstack.io/kusion/pkg/apis/core/v1"
Expand All @@ -12,6 +13,7 @@ import (

const (
defaultAWSProviderURL = "registry.terraform.io/hashicorp/aws/5.0.1"
awsRegionEnv = "AWS_REGION"
awsSecurityGroup = "aws_security_group"
awsDBInstance = "aws_db_instance"
)
Expand Down Expand Up @@ -60,9 +62,12 @@ func (g *postgresGenerator) generateAWSResources(db *postgres.PostgreSQL, spec *
}

// Get the aws provider region, and the region of the aws provider must be set.
awsProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider])
if err != nil {
return nil, err
var awsProviderRegion string
if awsProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider]); awsProviderRegion == "" {
awsProviderRegion = os.Getenv(awsRegionEnv)
}
if awsProviderRegion == "" {
return nil, fmt.Errorf("aws provider region should not be empty")
}

// Build random_password for aws_db_instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (g *postgresGenerator) Generate(spec *apiv1.Intent) error {
return err
}

switch providerType {
switch strings.ToLower(providerType) {
case "aws":
secret, err = g.generateAWSResources(db, spec)
case "alicloud":
Expand Down
7 changes: 3 additions & 4 deletions pkg/modules/inputs/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
const (
errInvalidProviderSource = "invalid provider source: %s"
errEmptyProviderVersion = "empty provider version"
errEmptyProviderRegion = "empty provider region for source: %s"
)

const (
Expand Down Expand Up @@ -75,11 +74,11 @@ func GetProviderURL(providerConfig *apiv1.ProviderConfig) (string, error) {
}

// GetProviderRegion returns the region of the terraform provider.
func GetProviderRegion(providerConfig *apiv1.ProviderConfig) (string, error) {
func GetProviderRegion(providerConfig *apiv1.ProviderConfig) string {
region, ok := providerConfig.GenericConfig["region"]
if !ok {
return "", fmt.Errorf(errEmptyProviderRegion, providerConfig.Source)
return ""
}

return region.(string), nil
return region.(string)
}
22 changes: 7 additions & 15 deletions pkg/modules/inputs/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ func TestGetProviderURL(t *testing.T) {

func TestGetProviderRegion(t *testing.T) {
tests := []struct {
name string
data *apiv1.ProviderConfig
expected string
expectedErr error
name string
data *apiv1.ProviderConfig
expected string
}{
{
name: "Valid Provider Config",
Expand All @@ -123,27 +122,20 @@ func TestGetProviderRegion(t *testing.T) {
"region": "us-east-1",
},
},
expected: "us-east-1",
expectedErr: nil,
expected: "us-east-1",
},
{
name: "Empty Provider Region",
data: &apiv1.ProviderConfig{
Source: "hashicorp/aws",
Version: "5.0.1",
},
expected: "",
expectedErr: fmt.Errorf(errEmptyProviderRegion, "hashicorp/aws"),
expected: "",
},
}

for _, test := range tests {
actual, actualErr := GetProviderRegion(test.data)
if test.expectedErr == nil {
assert.Equal(t, test.expected, actual)
assert.NoError(t, actualErr)
} else {
assert.ErrorContains(t, actualErr, test.expectedErr.Error())
}
actual := GetProviderRegion(test.data)
assert.Equal(t, test.expected, actual)
}
}

0 comments on commit f7c8fed

Please sign in to comment.