Skip to content

Commit

Permalink
fix(appset): add option to disable SCM providers entirely (#14246) (#…
Browse files Browse the repository at this point in the history
…15248)

* feat(appset): add option to disable SCM providers entirely (#14246)

Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com>

* clarify docs

Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com>

* more clarification, small refactor

Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com>

* more clarification

Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com>

* fix test

Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com>

* refactor

Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com>

* fix test assertion

Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com>

* simplify test expectation

Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com>

---------

Signed-off-by: Michael Crenshaw <350466+crenshaw-dev@users.noreply.github.com>
Co-authored-by: Ishita Sequeira <46771830+ishitasequeira@users.noreply.github.com>
  • Loading branch information
crenshaw-dev and ishitasequeira authored Oct 3, 2023
1 parent d90543d commit 43fe01a
Show file tree
Hide file tree
Showing 17 changed files with 220 additions and 70 deletions.
4 changes: 2 additions & 2 deletions applicationset/controllers/requeue_after_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func TestRequeueAfter(t *testing.T) {
"List": generators.NewListGenerator(),
"Clusters": generators.NewClusterGenerator(k8sClient, ctx, appClientset, "argocd"),
"Git": generators.NewGitGenerator(mockServer),
"SCMProvider": generators.NewSCMProviderGenerator(fake.NewClientBuilder().WithObjects(&corev1.Secret{}).Build(), generators.SCMAuthProviders{}, "", []string{""}),
"SCMProvider": generators.NewSCMProviderGenerator(fake.NewClientBuilder().WithObjects(&corev1.Secret{}).Build(), generators.SCMAuthProviders{}, "", []string{""}, true),
"ClusterDecisionResource": generators.NewDuckTypeGenerator(ctx, fakeDynClient, appClientset, "argocd"),
"PullRequest": generators.NewPullRequestGenerator(k8sClient, generators.SCMAuthProviders{}, "", []string{""}),
"PullRequest": generators.NewPullRequestGenerator(k8sClient, generators.SCMAuthProviders{}, "", []string{""}, true),
}

nestedGenerators := map[string]generators.Generator{
Expand Down
25 changes: 11 additions & 14 deletions applicationset/generators/pull_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ type PullRequestGenerator struct {
auth SCMAuthProviders
scmRootCAPath string
allowedSCMProviders []string
enableSCMProviders bool
}

func NewPullRequestGenerator(client client.Client, auth SCMAuthProviders, scmRootCAPath string, allowedScmProviders []string) Generator {
func NewPullRequestGenerator(client client.Client, auth SCMAuthProviders, scmRootCAPath string, allowedScmProviders []string, enableSCMProviders bool) Generator {
g := &PullRequestGenerator{
client: client,
auth: auth,
scmRootCAPath: scmRootCAPath,
allowedSCMProviders: allowedScmProviders,
enableSCMProviders: enableSCMProviders,
}
g.selectServiceProviderFunc = g.selectServiceProvider
return g
Expand Down Expand Up @@ -66,7 +68,7 @@ func (g *PullRequestGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
ctx := context.Background()
svc, err := g.selectServiceProviderFunc(ctx, appSetGenerator.PullRequest, applicationSetInfo)
if err != nil {
return nil, fmt.Errorf("failed to select pull request service provider: %v", err)
return nil, fmt.Errorf("failed to select pull request service provider: %w", err)
}

pulls, err := pullrequest.ListPullRequests(ctx, svc, appSetGenerator.PullRequest.Filters)
Expand Down Expand Up @@ -121,17 +123,18 @@ func (g *PullRequestGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha

// selectServiceProvider selects the provider to get pull requests from the configuration
func (g *PullRequestGenerator) selectServiceProvider(ctx context.Context, generatorConfig *argoprojiov1alpha1.PullRequestGenerator, applicationSetInfo *argoprojiov1alpha1.ApplicationSet) (pullrequest.PullRequestService, error) {
if !g.enableSCMProviders {
return nil, ErrSCMProvidersDisabled
}
if err := ScmProviderAllowed(applicationSetInfo, generatorConfig, g.allowedSCMProviders); err != nil {
return nil, fmt.Errorf("scm provider not allowed: %w", err)
}

if generatorConfig.Github != nil {
if !ScmProviderAllowed(applicationSetInfo, generatorConfig.Github.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", generatorConfig.Github.API)
}
return g.github(ctx, generatorConfig.Github, applicationSetInfo)
}
if generatorConfig.GitLab != nil {
providerConfig := generatorConfig.GitLab
if !ScmProviderAllowed(applicationSetInfo, providerConfig.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.API)
}
token, err := g.getSecretRef(ctx, providerConfig.TokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Secret token: %v", err)
Expand All @@ -140,9 +143,6 @@ func (g *PullRequestGenerator) selectServiceProvider(ctx context.Context, genera
}
if generatorConfig.Gitea != nil {
providerConfig := generatorConfig.Gitea
if !ScmProviderAllowed(applicationSetInfo, providerConfig.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", generatorConfig.Gitea.API)
}
token, err := g.getSecretRef(ctx, providerConfig.TokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Secret token: %v", err)
Expand All @@ -151,9 +151,6 @@ func (g *PullRequestGenerator) selectServiceProvider(ctx context.Context, genera
}
if generatorConfig.BitbucketServer != nil {
providerConfig := generatorConfig.BitbucketServer
if !ScmProviderAllowed(applicationSetInfo, providerConfig.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.API)
}
if providerConfig.BasicAuth != nil {
password, err := g.getSecretRef(ctx, providerConfig.BasicAuth.PasswordRef, applicationSetInfo.Namespace)
if err != nil {
Expand Down
36 changes: 29 additions & 7 deletions applicationset/generators/pull_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
cases := []struct {
name string
providerConfig *argoprojiov1alpha1.PullRequestGenerator
expectedError string
expectedError error
}{
{
name: "Error Github",
Expand All @@ -287,7 +287,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
expectedError: "failed to select pull request service provider: scm provider not allowed: https://myservice.mynamespace.svc.cluster.local",
expectedError: &ErrDisallowedSCMProvider{},
},
{
name: "Error Gitlab",
Expand All @@ -296,7 +296,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
expectedError: "failed to select pull request service provider: scm provider not allowed: https://myservice.mynamespace.svc.cluster.local",
expectedError: &ErrDisallowedSCMProvider{},
},
{
name: "Error Gitea",
Expand All @@ -305,7 +305,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
expectedError: "failed to select pull request service provider: scm provider not allowed: https://myservice.mynamespace.svc.cluster.local",
expectedError: &ErrDisallowedSCMProvider{},
},
{
name: "Error Bitbucket",
Expand All @@ -314,7 +314,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
expectedError: "failed to select pull request service provider: scm provider not allowed: https://myservice.mynamespace.svc.cluster.local",
expectedError: &ErrDisallowedSCMProvider{},
},
}

Expand All @@ -330,7 +330,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
"gitea.myorg.com",
"bitbucket.myorg.com",
"azuredevops.myorg.com",
})
}, true)

applicationSetInfo := argoprojiov1alpha1.ApplicationSet{
ObjectMeta: metav1.ObjectMeta{
Expand All @@ -346,7 +346,29 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
_, err := pullRequestGenerator.GenerateParams(&applicationSetInfo.Spec.Generators[0], &applicationSetInfo)

assert.Error(t, err, "Must return an error")
assert.Equal(t, testCaseCopy.expectedError, err.Error())
assert.ErrorAs(t, err, testCaseCopy.expectedError)
})
}
}

func TestSCMProviderDisabled_PRGenerator(t *testing.T) {
generator := NewPullRequestGenerator(nil, SCMAuthProviders{}, "", []string{}, false)

applicationSetInfo := argoprojiov1alpha1.ApplicationSet{
ObjectMeta: metav1.ObjectMeta{
Name: "set",
},
Spec: argoprojiov1alpha1.ApplicationSetSpec{
Generators: []argoprojiov1alpha1.ApplicationSetGenerator{{
PullRequest: &argoprojiov1alpha1.PullRequestGenerator{
Github: &argoprojiov1alpha1.PullRequestGeneratorGithub{
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
}},
},
}

_, err := generator.GenerateParams(&applicationSetInfo.Spec.Generators[0], &applicationSetInfo)
assert.ErrorIs(t, err, ErrSCMProvidersDisabled)
}
62 changes: 39 additions & 23 deletions applicationset/generators/scm_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package generators

import (
"context"
"errors"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -31,24 +32,26 @@ type SCMProviderGenerator struct {
SCMAuthProviders
scmRootCAPath string
allowedSCMProviders []string
enableSCMProviders bool
}

type SCMAuthProviders struct {
GitHubApps github_app_auth.Credentials
}

func NewSCMProviderGenerator(client client.Client, providers SCMAuthProviders, scmRootCAPath string, allowedSCMProviders []string) Generator {
func NewSCMProviderGenerator(client client.Client, providers SCMAuthProviders, scmRootCAPath string, allowedSCMProviders []string, enableSCMProviders bool) Generator {
return &SCMProviderGenerator{
client: client,
SCMAuthProviders: providers,
scmRootCAPath: scmRootCAPath,
allowedSCMProviders: allowedSCMProviders,
enableSCMProviders: enableSCMProviders,
}
}

// Testing generator
func NewTestSCMProviderGenerator(overrideProvider scm_provider.SCMProviderService) Generator {
return &SCMProviderGenerator{overrideProvider: overrideProvider}
return &SCMProviderGenerator{overrideProvider: overrideProvider, enableSCMProviders: true}
}

func (g *SCMProviderGenerator) GetRequeueAfter(appSetGenerator *argoprojiov1alpha1.ApplicationSetGenerator) time.Duration {
Expand All @@ -65,24 +68,44 @@ func (g *SCMProviderGenerator) GetTemplate(appSetGenerator *argoprojiov1alpha1.A
return &appSetGenerator.SCMProvider.Template
}

func ScmProviderAllowed(applicationSetInfo *argoprojiov1alpha1.ApplicationSet, url string, allowedScmProviders []string) bool {
var ErrSCMProvidersDisabled = errors.New("scm providers are disabled")

type ErrDisallowedSCMProvider struct {
Provider string
Allowed []string
}

func NewErrDisallowedSCMProvider(provider string, allowed []string) ErrDisallowedSCMProvider {
return ErrDisallowedSCMProvider{
Provider: provider,
Allowed: allowed,
}
}

func (e ErrDisallowedSCMProvider) Error() string {
return fmt.Sprintf("scm provider %q not allowed, must use one of the following: %s", e.Provider, strings.Join(e.Allowed, ", "))
}

func ScmProviderAllowed(applicationSetInfo *argoprojiov1alpha1.ApplicationSet, generator SCMGeneratorWithCustomApiUrl, allowedScmProviders []string) error {
url := generator.CustomApiUrl()

if url == "" || len(allowedScmProviders) == 0 {
return true
return nil
}

for _, allowedScmProvider := range allowedScmProviders {
if url == allowedScmProvider {
return true
return nil
}
}

log.WithFields(log.Fields{
common.SecurityField: common.SecurityMedium,
"applicationset": applicationSetInfo.Name,
"appSetNamespace": applicationSetInfo.Namespace,
}).Debugf("attempted to use disallowed SCM %q", url)
}).Debugf("attempted to use disallowed SCM %q, must use one of the following: %s", url, strings.Join(allowedScmProviders, ", "))

return false
return NewErrDisallowedSCMProvider(url, allowedScmProviders)
}

func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha1.ApplicationSetGenerator, applicationSetInfo *argoprojiov1alpha1.ApplicationSet) ([]map[string]interface{}, error) {
Expand All @@ -94,26 +117,28 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
return nil, EmptyAppSetGeneratorError
}

ctx := context.Background()
if !g.enableSCMProviders {
return nil, ErrSCMProvidersDisabled
}

// Create the SCM provider helper.
providerConfig := appSetGenerator.SCMProvider

if err := ScmProviderAllowed(applicationSetInfo, providerConfig, g.allowedSCMProviders); err != nil {
return nil, fmt.Errorf("scm provider not allowed: %w", err)
}

ctx := context.Background()
var provider scm_provider.SCMProviderService
if g.overrideProvider != nil {
provider = g.overrideProvider
} else if providerConfig.Github != nil {
if !ScmProviderAllowed(applicationSetInfo, providerConfig.Github.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.Github.API)
}
var err error
provider, err = g.githubProvider(ctx, providerConfig.Github, applicationSetInfo)
if err != nil {
return nil, fmt.Errorf("scm provider: %w", err)
}
} else if providerConfig.Gitlab != nil {
if !ScmProviderAllowed(applicationSetInfo, providerConfig.Gitlab.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.Gitlab.API)
}
token, err := g.getSecretRef(ctx, providerConfig.Gitlab.TokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Gitlab token: %v", err)
Expand All @@ -123,9 +148,6 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
return nil, fmt.Errorf("error initializing Gitlab service: %v", err)
}
} else if providerConfig.Gitea != nil {
if !ScmProviderAllowed(applicationSetInfo, providerConfig.Gitea.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.Gitea.API)
}
token, err := g.getSecretRef(ctx, providerConfig.Gitea.TokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Gitea token: %v", err)
Expand All @@ -136,9 +158,6 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
}
} else if providerConfig.BitbucketServer != nil {
providerConfig := providerConfig.BitbucketServer
if !ScmProviderAllowed(applicationSetInfo, providerConfig.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.API)
}
var scmError error
if providerConfig.BasicAuth != nil {
password, err := g.getSecretRef(ctx, providerConfig.BasicAuth.PasswordRef, applicationSetInfo.Namespace)
Expand All @@ -153,9 +172,6 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
return nil, fmt.Errorf("error initializing Bitbucket Server service: %v", scmError)
}
} else if providerConfig.AzureDevOps != nil {
if !ScmProviderAllowed(applicationSetInfo, providerConfig.AzureDevOps.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.AzureDevOps.API)
}
token, err := g.getSecretRef(ctx, providerConfig.AzureDevOps.AccessTokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Azure Devops access token: %v", err)
Expand Down
Loading

0 comments on commit 43fe01a

Please sign in to comment.