Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start counting ACME certificate issuance as client activity #20520

Merged
merged 12 commits into from
May 17, 2023
25 changes: 25 additions & 0 deletions builtin/logical/pki/acme_billing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package pki

import (
"context"
"fmt"

"github.com/hashicorp/vault/sdk/logical"
)

func (b *backend) doTrackBilling(ctx context.Context, identifiers []*ACMEIdentifier) error {
billingView, ok := b.System().(logical.ACMEBillingSystemView)
if !ok {
return fmt.Errorf("failed to perform cast to ACME billing system view interface")
}

var realized []string
for _, identifier := range identifiers {
realized = append(realized, fmt.Sprintf("%s/%s", identifier.Type, identifier.OriginalValue))
}

return billingView.CreateActivityCountEventForIdentifiers(ctx, realized)
}
296 changes: 296 additions & 0 deletions builtin/logical/pki/acme_billing_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package pki

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"strings"
"testing"
"time"

"golang.org/x/crypto/acme"

"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki/dnstest"
"github.com/hashicorp/vault/helper/constants"
"github.com/hashicorp/vault/helper/timeutil"

"github.com/stretchr/testify/require"
)

// TestACMEBilling is a basic test that will validate client counts created via ACME workflows.
func TestACMEBilling(t *testing.T) {
t.Parallel()
timeutil.SkipAtEndOfMonth(t)

cluster, client, _ := setupAcmeBackend(t)
defer cluster.Cleanup()

dns := dnstest.SetupResolver(t, "dadgarcorp.com")
defer dns.Cleanup()

// Enable additional mounts.
setupAcmeBackendOnClusterAtPath(t, cluster, client, "pki2")
setupAcmeBackendOnClusterAtPath(t, cluster, client, "ns1/pki")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this make sense in OSS?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OSS doesn't have namespaces, but it does have nested paths. So this does work and mounts it at ns1/pki, it just isn't a true Enterprise namespace.

setupAcmeBackendOnClusterAtPath(t, cluster, client, "ns2/pki")

// Enable custom DNS resolver for testing.
for _, mount := range []string{"pki", "pki2", "ns1/pki", "ns2/pki"} {
_, err := client.Logical().Write(mount+"/config/acme", map[string]interface{}{
"dns_resolver": dns.GetLocalAddr(),
})
require.NoError(t, err, "failed to set local dns resolver address for testing on mount: "+mount)
}

// Enable client counting.
_, err := client.Logical().Write("/sys/internal/counters/config", map[string]interface{}{
"enabled": "enable",
})
require.NoError(t, err, "failed to enable client counting")

// Setup ACME clients. We refresh account keys each time for consistency.
acmeClientPKI := getAcmeClientForCluster(t, cluster, "/v1/pki/acme/", nil)
acmeClientPKI2 := getAcmeClientForCluster(t, cluster, "/v1/pki2/acme/", nil)
acmeClientPKINS1 := getAcmeClientForCluster(t, cluster, "/v1/ns1/pki/acme/", nil)
acmeClientPKINS2 := getAcmeClientForCluster(t, cluster, "/v1/ns2/pki/acme/", nil)

// Get our initial count.
expectedCount := validateClientCount(t, client, "", -1, "initial fetch")

// Unique identifier: should increase by one.
doACMEForDomainWithDNS(t, dns, &acmeClientPKI, []string{"dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate")

// Different identifier; should increase by one.
doACMEForDomainWithDNS(t, dns, &acmeClientPKI, []string{"example.dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate")

// While same identifiers, used together and so thus are unique; increase by one.
doACMEForDomainWithDNS(t, dns, &acmeClientPKI, []string{"example.dadgarcorp.com", "dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate")

// Same identifiers in different order are not unique; keep the same.
doACMEForDomainWithDNS(t, dns, &acmeClientPKI, []string{"dadgarcorp.com", "example.dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "pki", expectedCount, "different order; same identifiers")

// Using a different mount shouldn't affect counts.
doACMEForDomainWithDNS(t, dns, &acmeClientPKI2, []string{"dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "", expectedCount, "different mount; same identifiers")

// But using a different identifier should.
doACMEForDomainWithDNS(t, dns, &acmeClientPKI2, []string{"pki2.dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "pki2", expectedCount+1, "different mount with different identifiers")

// A new identifier in a unique namespace will affect results.
doACMEForDomainWithDNS(t, dns, &acmeClientPKINS1, []string{"unique.dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "ns1/pki", expectedCount+1, "unique identifier in a namespace")

// But in a different namespace with the existing identifier will not.
doACMEForDomainWithDNS(t, dns, &acmeClientPKINS2, []string{"unique.dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "", expectedCount, "existing identifier in a namespace")
doACMEForDomainWithDNS(t, dns, &acmeClientPKI2, []string{"unique.dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "", expectedCount, "existing identifier outside of a namespace")

// Creating a unique identifier in a namespace with a mount with the
// same name as another namespace should increase counts as well.
doACMEForDomainWithDNS(t, dns, &acmeClientPKINS2, []string{"very-unique.dadgarcorp.com"})
expectedCount = validateClientCount(t, client, "ns2/pki", expectedCount+1, "unique identifier in a different namespace")
}

func validateClientCount(t *testing.T, client *api.Client, mount string, expected int64, message string) int64 {
resp, err := client.Logical().Read("/sys/internal/counters/activity/monthly")
require.NoError(t, err, "failed to fetch client count values")
t.Logf("got client count numbers: %v", resp)

require.NotNil(t, resp)
require.NotNil(t, resp.Data)
require.Contains(t, resp.Data, "non_entity_clients")
require.Contains(t, resp.Data, "months")

rawCount := resp.Data["non_entity_clients"].(json.Number)
count, err := rawCount.Int64()
require.NoError(t, err, "failed to parse number as int64: "+rawCount.String())

if expected != -1 {
require.Equal(t, expected, count, "value of client counts did not match expectations: "+message)
cipherboy marked this conversation as resolved.
Show resolved Hide resolved
}

if mount == "" {
return count
}

months := resp.Data["months"].([]interface{})
if len(months) > 1 {
t.Fatalf("running across a month boundary despite using SkipAtEndOfMonth(...); rerun test from start fully in the next month instead")
}

require.Equal(t, 1, len(months), "expected only a single month when running this test")

monthlyInfo := months[0].(map[string]interface{})

// Validate this month's aggregate counts match the overall value.
require.Contains(t, monthlyInfo, "counts", "expected monthly info to contain a count key")
monthlyCounts := monthlyInfo["counts"].(map[string]interface{})
require.Contains(t, monthlyCounts, "non_entity_clients", "expected month[0].counts to contain a non_entity_clients key")
monthlyCountNonEntityRaw := monthlyCounts["non_entity_clients"].(json.Number)
monthlyCountNonEntity, err := monthlyCountNonEntityRaw.Int64()
require.NoError(t, err, "failed to parse number as int64: "+monthlyCountNonEntityRaw.String())
require.Equal(t, count, monthlyCountNonEntity, "expected equal values for non entity client counts")

// Validate this mount's namespace is included in the namespaces list,
// if this is enterprise. Otherwise, if its OSS or we don't have a
// namespace, we default to the value root.
mountNamespace := "root"
mountPath := mount + "/"
if constants.IsEnterprise && strings.Contains(mount, "/") {
pieces := strings.Split(mount, "/")
require.Equal(t, 2, len(pieces), "we do not support nested namespaces in this test")
mountNamespace = pieces[0]
mountPath = pieces[1] + "/"
}

require.Contains(t, monthlyInfo, "namespaces", "expected monthly info to contain a namespaces key")
monthlyNamespaces := monthlyInfo["namespaces"].([]interface{})
foundNamespace := false
for index, namespaceRaw := range monthlyNamespaces {
namespace := namespaceRaw.(map[string]interface{})
require.Contains(t, namespace, "namespace_id", "expected monthly.namespaces[%v] to contain a namespace_id key", index)
namespaceId := namespace["namespace_id"].(string)

if namespaceId != mountNamespace {
t.Logf("skipping non-matching namespace %v: %v != %v / %v", index, namespaceId, mountNamespace, namespace)
continue
}

foundNamespace = true

// This namespace must have a non-empty aggregate non-entity count.
require.Contains(t, namespace, "counts", "expected monthly.namespaces[%v] to contain a counts key", index)
namespaceCounts := namespace["counts"].(map[string]interface{})
require.Contains(t, namespaceCounts, "non_entity_clients", "expected namespace counts to contain a non_entity_clients key")
namespaceCountNonEntityRaw := namespaceCounts["non_entity_clients"].(json.Number)
namespaceCountNonEntity, err := namespaceCountNonEntityRaw.Int64()
require.NoError(t, err, "failed to parse number as int64: "+namespaceCountNonEntityRaw.String())
require.Greater(t, namespaceCountNonEntity, int64(0), "expected at least one non-entity client count value in the namespace")

require.Contains(t, namespace, "mounts", "expected monthly.namespaces[%v] to contain a mounts key", index)
namespaceMounts := namespace["mounts"].([]interface{})
foundMount := false
for mountIndex, mountRaw := range namespaceMounts {
mountInfo := mountRaw.(map[string]interface{})
require.Contains(t, mountInfo, "mount_path", "expected monthly.namespaces[%v].mounts[%v] to contain a mount_path key", index, mountIndex)
mountInfoPath := mountInfo["mount_path"].(string)
if mountPath != mountInfoPath {
t.Logf("skipping non-matching mount path %v in namespace %v: %v != %v / %v of %v", mountIndex, index, mountPath, mountInfoPath, mountInfo, namespace)
continue
}

foundMount = true

// This mount must also have a non-empty non-entity client count.
require.Contains(t, mountInfo, "counts", "expected monthly.namespaces[%v].mounts[%v] to contain a counts key", index, mountIndex)
mountCounts := mountInfo["counts"].(map[string]interface{})
require.Contains(t, mountCounts, "non_entity_clients", "expected mount counts to contain a non_entity_clients key")
mountCountNonEntityRaw := mountCounts["non_entity_clients"].(json.Number)
mountCountNonEntity, err := mountCountNonEntityRaw.Int64()
require.NoError(t, err, "failed to parse number as int64: "+mountCountNonEntityRaw.String())
require.Greater(t, mountCountNonEntity, int64(0), "expected at least one non-entity client count value in the mount")
}

require.True(t, foundMount, "expected to find the mount "+mountPath+" in the list of mounts for namespace, but did not")
}

require.True(t, foundNamespace, "expected to find the namespace "+mountNamespace+" in the list of namespaces, but did not")

return count
}

func doACMEForDomainWithDNS(t *testing.T, dns *dnstest.TestServer, acmeClient *acme.Client, domains []string) *x509.Certificate {
cr := &x509.CertificateRequest{
Subject: pkix.Name{CommonName: domains[0]},
DNSNames: domains,
}

accountKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err, "failed to generate account key")
acmeClient.Key = accountKey

testCtx, cancelFunc := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancelFunc()

// Register the client.
_, err = acmeClient.Register(testCtx, &acme.Account{Contact: []string{"mailto:ipsans@dadgarcorp.com"}}, func(tosURL string) bool { return true })
require.NoError(t, err, "failed registering account")

// Create the Order
var orderIdentifiers []acme.AuthzID
for _, domain := range domains {
orderIdentifiers = append(orderIdentifiers, acme.AuthzID{Type: "dns", Value: domain})
}
order, err := acmeClient.AuthorizeOrder(testCtx, orderIdentifiers)
require.NoError(t, err, "failed creating ACME order")

// Fetch its authorizations.
var auths []*acme.Authorization
for _, authUrl := range order.AuthzURLs {
authorization, err := acmeClient.GetAuthorization(testCtx, authUrl)
require.NoError(t, err, "failed to lookup authorization at url: %s", authUrl)
auths = append(auths, authorization)
}

// For each dns-01 challenge, place the record in the associated DNS resolver.
var challengesToAccept []*acme.Challenge
for _, auth := range auths {
for _, challenge := range auth.Challenges {
if challenge.Status != acme.StatusPending {
t.Logf("ignoring challenge not in status pending: %v", challenge)
continue
}

if challenge.Type == "dns-01" {
challengeBody, err := acmeClient.DNS01ChallengeRecord(challenge.Token)
require.NoError(t, err, "failed generating challenge response")

dns.AddRecord("_acme-challenge."+auth.Identifier.Value, "TXT", challengeBody)
defer dns.RemoveRecord("_acme-challenge."+auth.Identifier.Value, "TXT", challengeBody)

require.NoError(t, err, "failed setting DNS record")

challengesToAccept = append(challengesToAccept, challenge)
}
}
}

dns.PushConfig()
require.GreaterOrEqual(t, len(challengesToAccept), 1, "Need at least one challenge, got none")

// Tell the ACME server, that they can now validate those challenges.
for _, challenge := range challengesToAccept {
_, err = acmeClient.Accept(testCtx, challenge)
require.NoError(t, err, "failed to accept challenge: %v", challenge)
}

// Wait for the order/challenges to be validated.
_, err = acmeClient.WaitOrder(testCtx, order.URI)
require.NoError(t, err, "failed waiting for order to be ready")

// Create/sign the CSR and ask ACME server to sign it returning us the final certificate
csrKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
csr, err := x509.CreateCertificateRequest(rand.Reader, cr, csrKey)
require.NoError(t, err, "failed generating csr")

certs, _, err := acmeClient.CreateOrderCert(testCtx, order.FinalizeURL, csr, false)
require.NoError(t, err, "failed to get a certificate back from ACME")

acmeCert, err := x509.ParseCertificate(certs[0])
require.NoError(t, err, "failed parsing acme cert bytes")

return acmeCert
}
5 changes: 5 additions & 0 deletions builtin/logical/pki/path_acme_order.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, _ *logical.Request,
return nil, fmt.Errorf("failed saving updated order: %w", err)
}

if err := b.doTrackBilling(ac.sc.Context, order.Identifiers); err != nil {
b.Logger().Error("failed to track billing for order", "order", orderId, "error", err)
err = nil
}

return formatOrderResponse(ac, order), nil
}

Expand Down
Loading