Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(appset): add option to disable SCM providers entirely (#14246) #15248

Merged
merged 11 commits into from
Oct 3, 2023
4 changes: 2 additions & 2 deletions applicationset/controllers/requeue_after_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func TestRequeueAfter(t *testing.T) {
"List": generators.NewListGenerator(),
"Clusters": generators.NewClusterGenerator(k8sClient, ctx, appClientset, "argocd"),
"Git": generators.NewGitGenerator(mockServer),
"SCMProvider": generators.NewSCMProviderGenerator(fake.NewClientBuilder().WithObjects(&corev1.Secret{}).Build(), generators.SCMAuthProviders{}, "", []string{""}),
"SCMProvider": generators.NewSCMProviderGenerator(fake.NewClientBuilder().WithObjects(&corev1.Secret{}).Build(), generators.SCMAuthProviders{}, "", []string{""}, true),
"ClusterDecisionResource": generators.NewDuckTypeGenerator(ctx, fakeDynClient, appClientset, "argocd"),
"PullRequest": generators.NewPullRequestGenerator(k8sClient, generators.SCMAuthProviders{}, "", []string{""}),
"PullRequest": generators.NewPullRequestGenerator(k8sClient, generators.SCMAuthProviders{}, "", []string{""}, true),
}

nestedGenerators := map[string]generators.Generator{
Expand Down
25 changes: 11 additions & 14 deletions applicationset/generators/pull_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ type PullRequestGenerator struct {
auth SCMAuthProviders
scmRootCAPath string
allowedSCMProviders []string
enableSCMProviders bool
}

func NewPullRequestGenerator(client client.Client, auth SCMAuthProviders, scmRootCAPath string, allowedScmProviders []string) Generator {
func NewPullRequestGenerator(client client.Client, auth SCMAuthProviders, scmRootCAPath string, allowedScmProviders []string, enableSCMProviders bool) Generator {
g := &PullRequestGenerator{
client: client,
auth: auth,
scmRootCAPath: scmRootCAPath,
allowedSCMProviders: allowedScmProviders,
enableSCMProviders: enableSCMProviders,
}
g.selectServiceProviderFunc = g.selectServiceProvider
return g
Expand Down Expand Up @@ -66,7 +68,7 @@ func (g *PullRequestGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
ctx := context.Background()
svc, err := g.selectServiceProviderFunc(ctx, appSetGenerator.PullRequest, applicationSetInfo)
if err != nil {
return nil, fmt.Errorf("failed to select pull request service provider: %v", err)
return nil, fmt.Errorf("failed to select pull request service provider: %w", err)
}

pulls, err := pullrequest.ListPullRequests(ctx, svc, appSetGenerator.PullRequest.Filters)
Expand Down Expand Up @@ -121,17 +123,18 @@ func (g *PullRequestGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha

// selectServiceProvider selects the provider to get pull requests from the configuration
func (g *PullRequestGenerator) selectServiceProvider(ctx context.Context, generatorConfig *argoprojiov1alpha1.PullRequestGenerator, applicationSetInfo *argoprojiov1alpha1.ApplicationSet) (pullrequest.PullRequestService, error) {
if !g.enableSCMProviders {
return nil, ErrSCMProvidersDisabled
}
if err := ScmProviderAllowed(applicationSetInfo, generatorConfig, g.allowedSCMProviders); err != nil {
return nil, fmt.Errorf("scm provider not allowed: %w", err)
}

if generatorConfig.Github != nil {
if !ScmProviderAllowed(applicationSetInfo, generatorConfig.Github.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", generatorConfig.Github.API)
}
return g.github(ctx, generatorConfig.Github, applicationSetInfo)
}
if generatorConfig.GitLab != nil {
providerConfig := generatorConfig.GitLab
if !ScmProviderAllowed(applicationSetInfo, providerConfig.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.API)
}
token, err := g.getSecretRef(ctx, providerConfig.TokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Secret token: %v", err)
Expand All @@ -140,9 +143,6 @@ func (g *PullRequestGenerator) selectServiceProvider(ctx context.Context, genera
}
if generatorConfig.Gitea != nil {
providerConfig := generatorConfig.Gitea
if !ScmProviderAllowed(applicationSetInfo, providerConfig.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", generatorConfig.Gitea.API)
}
token, err := g.getSecretRef(ctx, providerConfig.TokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Secret token: %v", err)
Expand All @@ -151,9 +151,6 @@ func (g *PullRequestGenerator) selectServiceProvider(ctx context.Context, genera
}
if generatorConfig.BitbucketServer != nil {
providerConfig := generatorConfig.BitbucketServer
if !ScmProviderAllowed(applicationSetInfo, providerConfig.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.API)
}
if providerConfig.BasicAuth != nil {
password, err := g.getSecretRef(ctx, providerConfig.BasicAuth.PasswordRef, applicationSetInfo.Namespace)
if err != nil {
Expand Down
36 changes: 29 additions & 7 deletions applicationset/generators/pull_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
cases := []struct {
name string
providerConfig *argoprojiov1alpha1.PullRequestGenerator
expectedError string
expectedError error
}{
{
name: "Error Github",
Expand All @@ -287,7 +287,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
expectedError: "failed to select pull request service provider: scm provider not allowed: https://myservice.mynamespace.svc.cluster.local",
expectedError: &ErrDisallowedSCMProvider{},
},
{
name: "Error Gitlab",
Expand All @@ -296,7 +296,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
expectedError: "failed to select pull request service provider: scm provider not allowed: https://myservice.mynamespace.svc.cluster.local",
expectedError: &ErrDisallowedSCMProvider{},
},
{
name: "Error Gitea",
Expand All @@ -305,7 +305,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
expectedError: "failed to select pull request service provider: scm provider not allowed: https://myservice.mynamespace.svc.cluster.local",
expectedError: &ErrDisallowedSCMProvider{},
},
{
name: "Error Bitbucket",
Expand All @@ -314,7 +314,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
expectedError: "failed to select pull request service provider: scm provider not allowed: https://myservice.mynamespace.svc.cluster.local",
expectedError: &ErrDisallowedSCMProvider{},
},
}

Expand All @@ -330,7 +330,7 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
"gitea.myorg.com",
"bitbucket.myorg.com",
"azuredevops.myorg.com",
})
}, true)

applicationSetInfo := argoprojiov1alpha1.ApplicationSet{
ObjectMeta: metav1.ObjectMeta{
Expand All @@ -346,7 +346,29 @@ func TestAllowedSCMProviderPullRequest(t *testing.T) {
_, err := pullRequestGenerator.GenerateParams(&applicationSetInfo.Spec.Generators[0], &applicationSetInfo)

assert.Error(t, err, "Must return an error")
assert.Equal(t, testCaseCopy.expectedError, err.Error())
assert.ErrorAs(t, err, testCaseCopy.expectedError)
})
}
}

func TestSCMProviderDisabled_PRGenerator(t *testing.T) {
generator := NewPullRequestGenerator(nil, SCMAuthProviders{}, "", []string{}, false)

applicationSetInfo := argoprojiov1alpha1.ApplicationSet{
ObjectMeta: metav1.ObjectMeta{
Name: "set",
},
Spec: argoprojiov1alpha1.ApplicationSetSpec{
Generators: []argoprojiov1alpha1.ApplicationSetGenerator{{
PullRequest: &argoprojiov1alpha1.PullRequestGenerator{
Github: &argoprojiov1alpha1.PullRequestGeneratorGithub{
API: "https://myservice.mynamespace.svc.cluster.local",
},
},
}},
},
}

_, err := generator.GenerateParams(&applicationSetInfo.Spec.Generators[0], &applicationSetInfo)
assert.ErrorIs(t, err, ErrSCMProvidersDisabled)
}
62 changes: 39 additions & 23 deletions applicationset/generators/scm_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package generators

import (
"context"
"errors"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -31,24 +32,26 @@ type SCMProviderGenerator struct {
SCMAuthProviders
scmRootCAPath string
allowedSCMProviders []string
enableSCMProviders bool
}

type SCMAuthProviders struct {
GitHubApps github_app_auth.Credentials
}

func NewSCMProviderGenerator(client client.Client, providers SCMAuthProviders, scmRootCAPath string, allowedSCMProviders []string) Generator {
func NewSCMProviderGenerator(client client.Client, providers SCMAuthProviders, scmRootCAPath string, allowedSCMProviders []string, enableSCMProviders bool) Generator {
return &SCMProviderGenerator{
client: client,
SCMAuthProviders: providers,
scmRootCAPath: scmRootCAPath,
allowedSCMProviders: allowedSCMProviders,
enableSCMProviders: enableSCMProviders,
}
}

// Testing generator
func NewTestSCMProviderGenerator(overrideProvider scm_provider.SCMProviderService) Generator {
return &SCMProviderGenerator{overrideProvider: overrideProvider}
return &SCMProviderGenerator{overrideProvider: overrideProvider, enableSCMProviders: true}
}

func (g *SCMProviderGenerator) GetRequeueAfter(appSetGenerator *argoprojiov1alpha1.ApplicationSetGenerator) time.Duration {
Expand All @@ -65,24 +68,44 @@ func (g *SCMProviderGenerator) GetTemplate(appSetGenerator *argoprojiov1alpha1.A
return &appSetGenerator.SCMProvider.Template
}

func ScmProviderAllowed(applicationSetInfo *argoprojiov1alpha1.ApplicationSet, url string, allowedScmProviders []string) bool {
var ErrSCMProvidersDisabled = errors.New("scm providers are disabled")

type ErrDisallowedSCMProvider struct {
Provider string
Allowed []string
}

func NewErrDisallowedSCMProvider(provider string, allowed []string) ErrDisallowedSCMProvider {
return ErrDisallowedSCMProvider{
Provider: provider,
Allowed: allowed,
}
}

func (e ErrDisallowedSCMProvider) Error() string {
return fmt.Sprintf("scm provider %q not allowed, must use one of the following: %s", e.Provider, strings.Join(e.Allowed, ", "))
}

func ScmProviderAllowed(applicationSetInfo *argoprojiov1alpha1.ApplicationSet, generator SCMGeneratorWithCustomApiUrl, allowedScmProviders []string) error {
url := generator.CustomApiUrl()

if url == "" || len(allowedScmProviders) == 0 {
return true
return nil
}

for _, allowedScmProvider := range allowedScmProviders {
if url == allowedScmProvider {
return true
return nil
}
}

log.WithFields(log.Fields{
common.SecurityField: common.SecurityMedium,
"applicationset": applicationSetInfo.Name,
"appSetNamespace": applicationSetInfo.Namespace,
}).Debugf("attempted to use disallowed SCM %q", url)
}).Debugf("attempted to use disallowed SCM %q, must use one of the following: %s", url, strings.Join(allowedScmProviders, ", "))

return false
return NewErrDisallowedSCMProvider(url, allowedScmProviders)
}

func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha1.ApplicationSetGenerator, applicationSetInfo *argoprojiov1alpha1.ApplicationSet) ([]map[string]interface{}, error) {
Expand All @@ -94,26 +117,28 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
return nil, EmptyAppSetGeneratorError
}

ctx := context.Background()
if !g.enableSCMProviders {
return nil, ErrSCMProvidersDisabled
}

// Create the SCM provider helper.
providerConfig := appSetGenerator.SCMProvider

if err := ScmProviderAllowed(applicationSetInfo, providerConfig, g.allowedSCMProviders); err != nil {
return nil, fmt.Errorf("scm provider not allowed: %w", err)
}

ctx := context.Background()
var provider scm_provider.SCMProviderService
if g.overrideProvider != nil {
provider = g.overrideProvider
} else if providerConfig.Github != nil {
if !ScmProviderAllowed(applicationSetInfo, providerConfig.Github.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.Github.API)
}
var err error
provider, err = g.githubProvider(ctx, providerConfig.Github, applicationSetInfo)
if err != nil {
return nil, fmt.Errorf("scm provider: %w", err)
}
} else if providerConfig.Gitlab != nil {
if !ScmProviderAllowed(applicationSetInfo, providerConfig.Gitlab.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.Gitlab.API)
}
token, err := g.getSecretRef(ctx, providerConfig.Gitlab.TokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Gitlab token: %v", err)
Expand All @@ -123,9 +148,6 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
return nil, fmt.Errorf("error initializing Gitlab service: %v", err)
}
} else if providerConfig.Gitea != nil {
if !ScmProviderAllowed(applicationSetInfo, providerConfig.Gitea.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.Gitea.API)
}
token, err := g.getSecretRef(ctx, providerConfig.Gitea.TokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Gitea token: %v", err)
Expand All @@ -136,9 +158,6 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
}
} else if providerConfig.BitbucketServer != nil {
providerConfig := providerConfig.BitbucketServer
if !ScmProviderAllowed(applicationSetInfo, providerConfig.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.API)
}
var scmError error
if providerConfig.BasicAuth != nil {
password, err := g.getSecretRef(ctx, providerConfig.BasicAuth.PasswordRef, applicationSetInfo.Namespace)
Expand All @@ -153,9 +172,6 @@ func (g *SCMProviderGenerator) GenerateParams(appSetGenerator *argoprojiov1alpha
return nil, fmt.Errorf("error initializing Bitbucket Server service: %v", scmError)
}
} else if providerConfig.AzureDevOps != nil {
if !ScmProviderAllowed(applicationSetInfo, providerConfig.AzureDevOps.API, g.allowedSCMProviders) {
return nil, fmt.Errorf("scm provider not allowed: %s", providerConfig.AzureDevOps.API)
}
token, err := g.getSecretRef(ctx, providerConfig.AzureDevOps.AccessTokenRef, applicationSetInfo.Namespace)
if err != nil {
return nil, fmt.Errorf("error fetching Azure Devops access token: %v", err)
Expand Down
Loading