Skip to content

Commit

Permalink
fix: inject account id to enumerator instead of repo
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Guibert committed Oct 5, 2022
1 parent e0104c8 commit 6335139
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 34 deletions.
4 changes: 2 additions & 2 deletions enumeration/remote/aws/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
repositoryCache := cache.New(100)

s3Repository := repository.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache)
s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), provider.accountId, repositoryCache)
s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), repositoryCache)
ec2repository := repository.NewEC2Repository(provider.session, repositoryCache)
elbv2Repository := repository.NewELBV2Repository(provider.session, repositoryCache)
route53repository := repository.NewRoute53Repository(provider.session, repositoryCache)
Expand Down Expand Up @@ -72,7 +72,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.Config, alerter))
remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.accountId, alerter))

remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer))
Expand Down
14 changes: 0 additions & 14 deletions enumeration/remote/aws/repository/mock_S3ControlRepository.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions enumeration/remote/aws/repository/s3control_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

type S3ControlRepository interface {
DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error)
GetAccountID() string
}

type s3ControlRepository struct {
Expand All @@ -18,10 +17,9 @@ type s3ControlRepository struct {
cache cache.Cache
}

func NewS3ControlRepository(factory client.AwsClientFactoryInterface, accountId string, c cache.Cache) *s3ControlRepository {
func NewS3ControlRepository(factory client.AwsClientFactoryInterface, c cache.Cache) *s3ControlRepository {
return &s3ControlRepository{
clientFactory: factory,
accountId: accountId,
cache: c,
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
tt.mocks(mockedClient)
factory := client.MockAwsClientFactoryInterface{}
factory.On("GetS3ControlClient", (*aws.Config)(nil)).Return(mockedClient).Once()
r := NewS3ControlRepository(&factory, "", store)
r := NewS3ControlRepository(&factory, store)
got, err := r.DescribeAccountPublicAccessBlock()
factory.AssertExpectations(t)
assert.Equal(t, tt.wantErr, err)
Expand Down
21 changes: 10 additions & 11 deletions enumeration/remote/aws/s3_account_public_access_block_enumerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@ import (
"github.com/snyk/driftctl/enumeration/alerter"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
tf "github.com/snyk/driftctl/enumeration/remote/terraform"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)

type S3AccountPublicAccessBlockEnumerator struct {
repository repository.S3ControlRepository
factory resource.ResourceFactory
providerConfig tf.TerraformProviderConfig
alerter alerter.AlerterInterface
repository repository.S3ControlRepository
factory resource.ResourceFactory
accountID string
alerter alerter.AlerterInterface
}

func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator {
func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, accountId string, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator {
return &S3AccountPublicAccessBlockEnumerator{
repository: repo,
factory: factory,
providerConfig: providerConfig,
alerter: alerter,
repository: repo,
factory: factory,
accountID: accountId,
alerter: alerter,
}
}

Expand All @@ -42,7 +41,7 @@ func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
e.repository.GetAccountID(),
e.accountID,
map[string]interface{}{
"block_public_acls": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicAcls),
"block_public_policy": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicPolicy),
Expand Down
6 changes: 3 additions & 3 deletions enumeration/remote/aws_s3_scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,7 @@ func TestS3BucketAnalytic(t *testing.T) {
func TestS3AccountPublicAccessBlock(t *testing.T) {
dummyError := errors.New("this is an error")

accountID := "123456"
tests := []struct {
test string
mocks func(*repository.MockS3ControlRepository, *mocks.AlerterInterface)
Expand All @@ -1080,7 +1081,6 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
{
test: "existing access block",
mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) {
repository.On("GetAccountID").Return("123456")
repository.On("DescribeAccountPublicAccessBlock").Return(&s3control.PublicAccessBlockConfiguration{
BlockPublicAcls: awssdk.Bool(false),
BlockPublicPolicy: awssdk.Bool(true),
Expand All @@ -1090,7 +1090,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
},
assertExpected: func(t *testing.T, got []*resource.Resource) {
assert.Len(t, got, 1)
assert.Equal(t, got[0].ResourceId(), "123456")
assert.Equal(t, got[0].ResourceId(), accountID)
assert.Equal(t, got[0].ResourceType(), resourceaws.AwsS3AccountPublicAccessBlock)
assert.Equal(t, got[0].Attributes(), &resource.Attributes{
"block_public_acls": false,
Expand Down Expand Up @@ -1125,7 +1125,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {

remoteLibrary.AddEnumerator(aws.NewS3AccountPublicAccessBlockEnumerator(
repo, factory,
tf.TerraformProviderConfig{DefaultAlias: "us-east-1"},
accountID,
alerter,
))

Expand Down

0 comments on commit 6335139

Please sign in to comment.