Skip to content

Commit

Permalink
Add a check for permission to createVolume for EBS and creds validati…
Browse files Browse the repository at this point in the history
…on for aws (#1339)

* add func to check ebs volume create permission

* add func to check creds validity + unit test

* update error msgs

Co-authored-by: Le Tran <le.tran@kasten.io>
  • Loading branch information
leuyentran and Le Tran committed Mar 30, 2022
1 parent 2de85f2 commit 247e548
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
19 changes: 18 additions & 1 deletion pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ func GetCredentials(ctx context.Context, config map[string]string) (*credentials
if err != nil {
return nil, err
}

// check if role switching is needed, then return creds
return switchAWSRole(ctx, creds, config[ConfigRole], assumedRole, assumeRoleDuration)
}
Expand Down Expand Up @@ -148,3 +147,21 @@ func GetConfig(ctx context.Context, config map[string]string) (awsConfig *aws.Co
}
return &aws.Config{Credentials: creds}, region, nil
}

func IsAwsCredsValid(ctx context.Context, config map[string]string) (bool, error) {
var maxRetries int = 10
awsConfig, region, err := GetConfig(ctx, config)
if err != nil {
return false, errors.Wrap(err, "Failed to get config for AWS creds")
}
s, err := session.NewSession(awsConfig)
if err != nil {
return false, errors.Wrap(err, "Failed to create session with provided creds")
}
stsCli := sts.New(s, aws.NewConfig().WithRegion(region).WithMaxRetries(maxRetries))
_, err = stsCli.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return false, errors.Wrap(err, "Failed to get user with provided creds")
}
return true, nil
}
49 changes: 49 additions & 0 deletions pkg/aws/aws_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2022 The Kanister Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package aws

import (
"context"
"testing"

"gopkg.in/check.v1"

envconfig "github.com/kanisterio/kanister/pkg/config"
)

// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }

type AWSSuite struct{}

var _ = check.Suite(&AWSSuite{})

func (s AWSSuite) TestValidCreds(c *check.C) {
ctx := context.Background()
config := map[string]string{}

config[AccessKeyID] = envconfig.GetEnvOrSkip(c, AccessKeyID)
config[SecretAccessKey] = envconfig.GetEnvOrSkip(c, SecretAccessKey)
config[ConfigRegion] = "us-west-2"

res, err := IsAwsCredsValid(ctx, config)
c.Assert(err, check.IsNil)
c.Assert(res, check.Equals, true)

config[AccessKeyID] = "fake-access-id"
res, err = IsAwsCredsValid(ctx, config)
c.Assert(err, check.NotNil)
c.Assert(res, check.Equals, false)
}
34 changes: 34 additions & 0 deletions pkg/blockstorage/awsebs/awsebs.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,40 @@ func (s *EbsStorage) VolumeCreate(ctx context.Context, volume blockstorage.Volum
return s.VolumeGet(ctx, volID, volume.Az)
}

// CheckVolumeCreate checks if client as permission to create volumes
func (s *EbsStorage) CheckVolumeCreate(ctx context.Context) (bool, error) {
var zoneName *string
var err error
var size int64 = 1
var dryRun bool = true

ec2Cli, err := newEC2Client(*s.config.Region, s.config)
if err != nil {
return false, errors.Wrap(err, "Could not get EC2 client")
}
dai := &ec2.DescribeAvailabilityZonesInput{}
az, err := ec2Cli.DescribeAvailabilityZones(dai)
if err != nil {
return false, errors.New("Fail to get available zone for EC2 client")
}
if az != nil {
zoneName = az.AvailabilityZones[1].ZoneName
} else {
return false, errors.New("No available zone for EC2 client")
}

cvi := &ec2.CreateVolumeInput{
AvailabilityZone: zoneName,
Size: &size,
DryRun: &dryRun,
}
_, err = s.Ec2Cli.CreateVolume(cvi)
if !isDryRunErr(err) {
return false, errors.Wrap(err, "Could not create volume with EC2 client")
}
return true, nil
}

// VolumeGet is part of blockstorage.Provider
func (s *EbsStorage) VolumeGet(ctx context.Context, id string, zone string) (*blockstorage.Volume, error) {
volIDs := []*string{aws.String(id)}
Expand Down

0 comments on commit 247e548

Please sign in to comment.