diff --git a/pkg/apis/application/v1alpha1/repository_types.go b/pkg/apis/application/v1alpha1/repository_types.go index 862627f216f0a..21e43286292df 100644 --- a/pkg/apis/application/v1alpha1/repository_types.go +++ b/pkg/apis/application/v1alpha1/repository_types.go @@ -195,17 +195,27 @@ func (repo *Repository) GetHelmCreds() helm.Creds { } func getCAPath(repoURL string) string { - if git.IsHTTPSURL(repoURL) { - if parsedURL, err := url.Parse(repoURL); err == nil { - if caPath, err := cert.GetCertBundlePathForRepository(parsedURL.Host); err == nil { - return caPath - } else { - log.Warnf("Could not get cert bundle path for host '%s'", parsedURL.Host) - } + hostname := "" + + // url.Parse() will happily parse most things thrown at it. When the URL + // is either https or oci, we use the parsed hostname to receive the cert, + // otherwise we'll use the parsed path (OCI repos are often specified as + // hostname, without protocol). + if parsedURL, err := url.Parse(repoURL); err == nil { + if parsedURL.Scheme == "https" || parsedURL.Scheme == "oci" { + hostname = parsedURL.Host + } else if parsedURL.Scheme == "" { + hostname = parsedURL.Path + } + } else { + log.Warnf("Could not parse repo URL '%s': %v", repoURL, err) + } + + if hostname != "" { + if caPath, err := cert.GetCertBundlePathForRepository(hostname); err == nil { + return caPath } else { - // We don't fail if we cannot parse the URL, but log a warning in that - // case. And we execute the command in a verbatim way. - log.Warnf("Could not parse repo URL '%s'", repoURL) + log.Warnf("Could not get cert bundle path for repository '%s': %v", repoURL, err) } } return "" diff --git a/pkg/apis/application/v1alpha1/types_test.go b/pkg/apis/application/v1alpha1/types_test.go index bc46c130cb188..40d66c0ab2719 100644 --- a/pkg/apis/application/v1alpha1/types_test.go +++ b/pkg/apis/application/v1alpha1/types_test.go @@ -2,10 +2,14 @@ package v1alpha1 import ( fmt "fmt" + "io/ioutil" + "os" + "path" "reflect" "testing" "time" + argocdcommon "github.com/argoproj/argo-cd/v2/common" "k8s.io/utils/pointer" "github.com/argoproj/gitops-engine/pkg/sync/common" @@ -2592,3 +2596,45 @@ func Test_validateGroupName(t *testing.T) { }) } } + +func TestGetCAPath(t *testing.T) { + + temppath, err := ioutil.TempDir("", "argocd-cert-test") + if err != nil { + panic(err) + } + cert, err := ioutil.ReadFile("../../../../test/fixture/certs/argocd-test-server.crt") + if err != nil { + panic(err) + } + err = ioutil.WriteFile(path.Join(temppath, "foo.example.com"), cert, 0666) + if err != nil { + panic(err) + } + defer os.RemoveAll(temppath) + os.Setenv(argocdcommon.EnvVarTLSDataPath, temppath) + validcert := []string{ + "https://foo.example.com", + "oci://foo.example.com", + "foo.example.com", + } + invalidpath := []string{ + "https://bar.example.com", + "oci://bar.example.com", + "bar.example.com", + "ssh://foo.example.com", + "/some/invalid/thing", + "../another/invalid/thing", + "./also/invalid", + "$invalid/as/well", + } + + for _, str := range validcert { + path := getCAPath(str) + assert.NotEmpty(t, path) + } + for _, str := range invalidpath { + path := getCAPath(str) + assert.Empty(t, path) + } +} diff --git a/util/cert/cert.go b/util/cert/cert.go index 2e06ec4f2a88d..e146eb2a59464 100644 --- a/util/cert/cert.go +++ b/util/cert/cert.go @@ -112,11 +112,11 @@ func GetSSHKnownHostsDataPath() string { func DecodePEMCertificateToX509(pemData string) (*x509.Certificate, error) { decodedData, _ := pem.Decode([]byte(pemData)) if decodedData == nil { - return nil, errors.New("Could not decode PEM data from input.") + return nil, errors.New("could not decode PEM data from input") } x509Cert, err := x509.ParseCertificate(decodedData.Bytes) if err != nil { - return nil, errors.New("Could not parse X509 data from input.") + return nil, errors.New("could not parse X509 data from input") } return x509Cert, nil } @@ -171,7 +171,7 @@ func ParseTLSCertificatesFromStream(stream io.Reader) ([]string, error) { } if certLine > CertificateMaxLines { - return nil, errors.New("Maximum number of lines exceeded during certificate parsing.") + return nil, errors.New("maximum number of lines exceeded during certificate parsing") } } @@ -233,7 +233,7 @@ func IsValidSSHKnownHostsEntry(line string) bool { func TokenizeSSHKnownHostsEntry(knownHostsEntry string) (string, string, []byte, error) { knownHostsToken := strings.SplitN(knownHostsEntry, " ", 3) if len(knownHostsToken) != 3 { - return "", "", nil, fmt.Errorf("Error while tokenizing input data.") + return "", "", nil, fmt.Errorf("error while tokenizing input data") } return knownHostsToken[0], knownHostsToken[1], []byte(knownHostsToken[2]), nil } @@ -301,7 +301,17 @@ func ServerNameWithoutPort(serverName string) string { // Load certificate data from a file. If the file does not exist, we do not // consider it an error and just return empty data. func GetCertificateForConnect(serverName string) ([]string, error) { - certPath := fmt.Sprintf("%s/%s", GetTLSCertificateDataPath(), ServerNameWithoutPort(serverName)) + dataPath := GetTLSCertificateDataPath() + if !strings.HasSuffix(dataPath, "/") { + dataPath += "/" + } + certPath, err := filepath.Abs(filepath.Join(dataPath, ServerNameWithoutPort(serverName))) + if err != nil { + return nil, err + } + if !strings.HasPrefix(certPath, dataPath) { + return nil, fmt.Errorf("could not get certificate for host %s", serverName) + } certificates, err := ParseTLSCertificatesFromPath(certPath) if err != nil { if os.IsNotExist(err) { @@ -312,7 +322,7 @@ func GetCertificateForConnect(serverName string) ([]string, error) { } if len(certificates) == 0 { - return nil, fmt.Errorf("No certificates found in existing file.") + return nil, fmt.Errorf("no certificates found in existing file") } return certificates, nil diff --git a/util/cert/cert_test.go b/util/cert/cert_test.go index 6ac839f2efac5..75459d9f7d35f 100644 --- a/util/cert/cert_test.go +++ b/util/cert/cert_test.go @@ -517,7 +517,7 @@ func TestGetCertificateForConnect(t *testing.T) { certs, err := GetCertificateForConnect("127.0.0.1") assert.Error(t, err) assert.Len(t, certs, 0) - assert.Contains(t, err.Error(), "No certificates found") + assert.Contains(t, err.Error(), "no certificates found") }) }