Skip to content

Commit

Permalink
Merge pull request #1865 from kevinschoonover/kevinschoonover/fix-pki…
Browse files Browse the repository at this point in the history
…-rotation

use lifespan instead of duration for calculating when cert should be rotate
  • Loading branch information
divyaac authored Apr 19, 2024
2 parents ee19d70 + d876ab2 commit 8e8026b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 19 deletions.
34 changes: 17 additions & 17 deletions dependency/vault_pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,26 +118,26 @@ func goodFor(cert *x509.Certificate) (time.Duration, bool) {
if cert == nil {
return 0, false
}
// These are all int64's with Seconds since the Epoch, handy for the math
start, end := cert.NotBefore.Unix(), cert.NotAfter.Unix()
now := time.Now().UTC().Unix()
if end <= now { // already expired
start, end := cert.NotBefore.UTC(), cert.NotAfter.UTC()
now := time.Now().UTC()
if end.Before(now) || end.Equal(now) { // already expired
return 0, false
}
lifespan := end - start // full ttl of cert
duration := end - now // duration remaining
gooddur := (duration * 9) / 10 // 90% of duration
mindur := (lifespan / 10) // 10% of lifespan
if gooddur <= mindur {
return 0, false // almost expired, get a new one
}
if gooddur > 100 { // 100 seconds
// add jitter if big enough for it to matter
r := rand.New(rand.NewSource(time.Now().UnixNano()))
// between 87% and 93%
gooddur = gooddur + ((gooddur / 100) * int64(r.Intn(6)-3))

lifespanDur := end.Sub(start)
r := rand.New(rand.NewSource(time.Now().UnixNano()))
lifespanMilliseconds := lifespanDur.Milliseconds()
// calculate the 'time the certificate should be rotated' by figuring out
// 87-93% of the lifespan and adding it to the start
rotationTime := start.Add(time.Millisecond * time.Duration(((lifespanMilliseconds*9)/10)+(lifespanMilliseconds*int64(r.Intn(6)-3))/100))

// after we have the 'time the certificate should be rotated', figure out how
// far it is from now to sleep
sleepFor := time.Duration(rotationTime.Sub(now))
if sleepFor <= 0 {
return 0, false
}
sleepFor := time.Duration(gooddur * 1e9) // basically: gooddur*time.Second

return sleepFor, true
}

Expand Down
57 changes: 55 additions & 2 deletions dependency/vault_pki_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ package dependency

import (
"bytes"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"os"
"strings"
"testing"
"time"

"github.com/hashicorp/consul-template/renderer"
"github.com/hashicorp/vault/api"
Expand Down Expand Up @@ -53,6 +57,46 @@ func Test_VaultPKI_notGoodFor(t *testing.T) {
}
}

func Test_VaulkPKI_goodFor(t *testing.T) {
tests := map[string]struct {
CertificateTTL time.Duration
}{
"one minute": {CertificateTTL: time.Minute},
"one hour": {CertificateTTL: time.Hour},
"one day": {CertificateTTL: time.Hour * 24},
"one week": {CertificateTTL: time.Hour * 24 * 7},
}
for name, tc := range tests {
NotBefore := time.Now()
NotAfter := time.Now().Add(tc.CertificateTTL)
certificate := x509.Certificate{
Subject: pkix.Name{
Organization: []string{"Acme Co"},
},
NotBefore: NotBefore,
NotAfter: NotAfter,
}

dur, ok := goodFor(&certificate)
if ok == false {
t.Errorf("%v: should be true", name)
}

ratio := dur.Seconds() / (NotAfter.Sub(NotBefore).Seconds())
// allow for a .01 epsilon for floating point comparison to prevent flakey tests
if ratio < .86 || ratio > .94 {
fmt.Println(ratio)
t.Errorf(
"%v: should be between 87 and 93, but was %.2f. NotBefore: %s, NotAfter: %s",
name,
ratio,
NotBefore,
NotAfter,
)
}
}
}

func Test_VaultPKI_pemsCert(t *testing.T) {
// tests w/ valid pems, and having it hidden behind various things
want := strings.TrimRight(strings.TrimSpace(validCert), "\n")
Expand Down Expand Up @@ -136,10 +180,16 @@ func Test_VaultPKI_refetch(t *testing.T) {
defer os.Remove(f.Name())

clients := testClients
TTL := "2s"
ttlDuration, err := time.ParseDuration(TTL)
if err != nil {
t.Fatal(err)
}

/// above is prep work
data := map[string]interface{}{
"common_name": "foo.example.com",
"ttl": "3s",
"ttl": TTL,
"ip_sans": "127.0.0.1,192.168.2.2",
}
d, err := NewVaultPKIQuery("pki/issue/example-dot-com", f.Name(), data)
Expand Down Expand Up @@ -189,7 +239,10 @@ func Test_VaultPKI_refetch(t *testing.T) {
t.Errorf("pemss don't match and should.")
}

// Don't pre-drain here as we want it to get a new pems
// forcefully wait the longest the certificate could be good force to ensure
// goodFor will always return needs renewal
<-d.sleepCh
time.Sleep(time.Millisecond * time.Duration(((ttlDuration.Milliseconds()*9)/10)+(ttlDuration.Milliseconds()*int64(3)/100)))
act3, rm, err := d.Fetch(clients, nil)
if err != nil {
t.Fatal(err)
Expand Down

0 comments on commit 8e8026b

Please sign in to comment.