Skip to content

Commit

Permalink
fix: Helm OCI repositories with custom CAs
Browse files Browse the repository at this point in the history
Signed-off-by: jannfis <jann@mistrust.net>
  • Loading branch information
jannfis committed Apr 7, 2022
1 parent 5921e2a commit a3583cb
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 17 deletions.
30 changes: 20 additions & 10 deletions pkg/apis/application/v1alpha1/repository_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,27 @@ func (repo *Repository) GetHelmCreds() helm.Creds {
}

func getCAPath(repoURL string) string {
if git.IsHTTPSURL(repoURL) {
if parsedURL, err := url.Parse(repoURL); err == nil {
if caPath, err := cert.GetCertBundlePathForRepository(parsedURL.Host); err == nil {
return caPath
} else {
log.Warnf("Could not get cert bundle path for host '%s'", parsedURL.Host)
}
hostname := ""

// url.Parse() will happily parse most things thrown at it. When the URL
// is either https or oci, we use the parsed hostname to receive the cert,
// otherwise we'll use the parsed path (OCI repos are often specified as
// hostname, without protocol).
if parsedURL, err := url.Parse(repoURL); err == nil {
if parsedURL.Scheme == "https" || parsedURL.Scheme == "oci" {
hostname = parsedURL.Host
} else if parsedURL.Scheme == "" {
hostname = parsedURL.Path
}
} else {
log.Warnf("Could not parse repo URL '%s': %v", repoURL, err)
}

if hostname != "" {
if caPath, err := cert.GetCertBundlePathForRepository(hostname); err == nil {
return caPath
} else {
// We don't fail if we cannot parse the URL, but log a warning in that
// case. And we execute the command in a verbatim way.
log.Warnf("Could not parse repo URL '%s'", repoURL)
log.Warnf("Could not get cert bundle path for repository '%s': %v", repoURL, err)
}
}
return ""
Expand Down
46 changes: 46 additions & 0 deletions pkg/apis/application/v1alpha1/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package v1alpha1

import (
fmt "fmt"
"io/ioutil"
"os"
"path"
"reflect"
"testing"
"time"

argocdcommon "github.com/argoproj/argo-cd/v2/common"
"k8s.io/utils/pointer"

"github.com/argoproj/gitops-engine/pkg/sync/common"
Expand Down Expand Up @@ -2592,3 +2596,45 @@ func Test_validateGroupName(t *testing.T) {
})
}
}

func TestGetCAPath(t *testing.T) {

temppath, err := ioutil.TempDir("", "argocd-cert-test")
if err != nil {
panic(err)
}
cert, err := ioutil.ReadFile("../../../../test/fixture/certs/argocd-test-server.crt")
if err != nil {
panic(err)
}
err = ioutil.WriteFile(path.Join(temppath, "foo.example.com"), cert, 0666)
if err != nil {
panic(err)
}
defer os.RemoveAll(temppath)
os.Setenv(argocdcommon.EnvVarTLSDataPath, temppath)
validcert := []string{
"https://foo.example.com",
"oci://foo.example.com",
"foo.example.com",
}
invalidpath := []string{
"https://bar.example.com",
"oci://bar.example.com",
"bar.example.com",
"ssh://foo.example.com",
"/some/invalid/thing",
"../another/invalid/thing",
"./also/invalid",
"$invalid/as/well",
}

for _, str := range validcert {
path := getCAPath(str)
assert.NotEmpty(t, path)
}
for _, str := range invalidpath {
path := getCAPath(str)
assert.Empty(t, path)
}
}
22 changes: 16 additions & 6 deletions util/cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ func GetSSHKnownHostsDataPath() string {
func DecodePEMCertificateToX509(pemData string) (*x509.Certificate, error) {
decodedData, _ := pem.Decode([]byte(pemData))
if decodedData == nil {
return nil, errors.New("Could not decode PEM data from input.")
return nil, errors.New("could not decode PEM data from input")
}
x509Cert, err := x509.ParseCertificate(decodedData.Bytes)
if err != nil {
return nil, errors.New("Could not parse X509 data from input.")
return nil, errors.New("could not parse X509 data from input")
}
return x509Cert, nil
}
Expand Down Expand Up @@ -171,7 +171,7 @@ func ParseTLSCertificatesFromStream(stream io.Reader) ([]string, error) {
}

if certLine > CertificateMaxLines {
return nil, errors.New("Maximum number of lines exceeded during certificate parsing.")
return nil, errors.New("maximum number of lines exceeded during certificate parsing")
}
}

Expand Down Expand Up @@ -233,7 +233,7 @@ func IsValidSSHKnownHostsEntry(line string) bool {
func TokenizeSSHKnownHostsEntry(knownHostsEntry string) (string, string, []byte, error) {
knownHostsToken := strings.SplitN(knownHostsEntry, " ", 3)
if len(knownHostsToken) != 3 {
return "", "", nil, fmt.Errorf("Error while tokenizing input data.")
return "", "", nil, fmt.Errorf("error while tokenizing input data")
}
return knownHostsToken[0], knownHostsToken[1], []byte(knownHostsToken[2]), nil
}
Expand Down Expand Up @@ -301,7 +301,17 @@ func ServerNameWithoutPort(serverName string) string {
// Load certificate data from a file. If the file does not exist, we do not
// consider it an error and just return empty data.
func GetCertificateForConnect(serverName string) ([]string, error) {
certPath := fmt.Sprintf("%s/%s", GetTLSCertificateDataPath(), ServerNameWithoutPort(serverName))
dataPath := GetTLSCertificateDataPath()
if !strings.HasSuffix(dataPath, "/") {
dataPath += "/"
}
certPath, err := filepath.Abs(filepath.Join(dataPath, ServerNameWithoutPort(serverName)))
if err != nil {
return nil, err
}
if !strings.HasPrefix(certPath, dataPath) {
return nil, fmt.Errorf("could not get certificate for host %s", serverName)
}
certificates, err := ParseTLSCertificatesFromPath(certPath)
if err != nil {
if os.IsNotExist(err) {
Expand All @@ -312,7 +322,7 @@ func GetCertificateForConnect(serverName string) ([]string, error) {
}

if len(certificates) == 0 {
return nil, fmt.Errorf("No certificates found in existing file.")
return nil, fmt.Errorf("no certificates found in existing file")
}

return certificates, nil
Expand Down
2 changes: 1 addition & 1 deletion util/cert/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ func TestGetCertificateForConnect(t *testing.T) {
certs, err := GetCertificateForConnect("127.0.0.1")
assert.Error(t, err)
assert.Len(t, certs, 0)
assert.Contains(t, err.Error(), "No certificates found")
assert.Contains(t, err.Error(), "no certificates found")
})

}
Expand Down

0 comments on commit a3583cb

Please sign in to comment.