Skip to content

Commit

Permalink
backport of commit 195dfca (#28791)
Browse files Browse the repository at this point in the history
Co-authored-by: miagilepner <mia.epner@hashicorp.com>
  • Loading branch information
1 parent 328fbc2 commit d92fac4
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 32 deletions.
5 changes: 4 additions & 1 deletion vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/go-secure-stdlib/tlsutil"
"github.com/hashicorp/go-uuid"
lru "github.com/hashicorp/golang-lru/v2"
kv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit"
Expand Down Expand Up @@ -628,7 +629,9 @@ type Core struct {
// Stop channel for raft TLS rotations
raftTLSRotationStopCh chan struct{}
// Stores the pending peers we are waiting to give answers
pendingRaftPeers *sync.Map
pendingRaftPeers *lru.Cache[string, *raftBootstrapChallenge]
// holds the lock for modifying pendingRaftPeers
pendingRaftPeersLock sync.RWMutex

// rawConfig stores the config as-is from the provided server configuration.
rawConfig *atomic.Value
Expand Down
166 changes: 166 additions & 0 deletions vault/external_tests/raft/raft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"context"
"crypto/md5"
"encoding/base64"
"errors"
"fmt"
"io"
Expand All @@ -16,7 +17,9 @@ import (
"testing"
"time"

"github.com/golang/protobuf/proto"
"github.com/hashicorp/go-cleanhttp"
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
Expand Down Expand Up @@ -248,6 +251,169 @@ func TestRaft_Retry_Join(t *testing.T) {
})
}

// TestRaftChallenge_sameAnswerSameID_concurrent verifies that 10 goroutines
// all requesting a raft challenge with the same ID all return the same answer.
// This is a regression test for a TOCTTOU race found during testing.
func TestRaftChallenge_sameAnswerSameID_concurrent(t *testing.T) {
t.Parallel()

cluster, _ := raftCluster(t, &RaftClusterOpts{
DisableFollowerJoins: true,
NumCores: 1,
})
defer cluster.Cleanup()
client := cluster.Cores[0].Client

challenges := make(chan string, 15)
wg := sync.WaitGroup{}
for i := 0; i < 15; i++ {
wg.Add(1)
go func() {
defer wg.Done()
res, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
"server_id": "node1",
})
require.NoError(t, err)
challenges <- res.Data["challenge"].(string)
}()
}

wg.Wait()
challengeSet := make(map[string]struct{})
close(challenges)
for challenge := range challenges {
challengeSet[challenge] = struct{}{}
}

require.Len(t, challengeSet, 1)
}

// TestRaftChallenge_sameAnswerSameID verifies that repeated bootstrap requests
// with the same node ID return the same challenge, but that a different node ID
// returns a different challenge
func TestRaftChallenge_sameAnswerSameID(t *testing.T) {
t.Parallel()

cluster, _ := raftCluster(t, &RaftClusterOpts{
DisableFollowerJoins: true,
NumCores: 1,
})
defer cluster.Cleanup()
client := cluster.Cores[0].Client
res, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
"server_id": "node1",
})
require.NoError(t, err)

// querying the same ID returns the same challenge
challenge := res.Data["challenge"]
resSameID, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
"server_id": "node1",
})
require.NoError(t, err)
require.Equal(t, challenge, resSameID.Data["challenge"])

// querying a different ID returns a new challenge
resDiffID, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
"server_id": "node2",
})
require.NoError(t, err)
require.NotEqual(t, challenge, resDiffID.Data["challenge"])
}

// TestRaftChallenge_evicted verifies that a valid answer errors if there have
// been more than 20 challenge requests after it, because our cache of pending
// bootstraps is limited to 20
func TestRaftChallenge_evicted(t *testing.T) {
t.Parallel()
cluster, _ := raftCluster(t, &RaftClusterOpts{
DisableFollowerJoins: true,
NumCores: 1,
})
defer cluster.Cleanup()
firstResponse := map[string]interface{}{}
client := cluster.Cores[0].Client
for i := 0; i < vault.RaftInitialChallengeLimit+1; i++ {
if i == vault.RaftInitialChallengeLimit {
// wait before sending the last request, so we don't get rate
// limited
time.Sleep(2 * time.Second)
}
res, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
"server_id": fmt.Sprintf("node-%d", i),
})
require.NoError(t, err)

// save the response from the first challenge
if i == 0 {
firstResponse = res.Data
}
}

// get the answer to the challenge
challengeRaw, err := base64.StdEncoding.DecodeString(firstResponse["challenge"].(string))
require.NoError(t, err)
eBlob := &wrapping.BlobInfo{}
err = proto.Unmarshal(challengeRaw, eBlob)
require.NoError(t, err)
access := cluster.Cores[0].SealAccess().GetAccess()
multiWrapValue := &vaultseal.MultiWrapValue{
Generation: access.Generation(),
Slots: []*wrapping.BlobInfo{eBlob},
}
plaintext, _, err := access.Decrypt(context.Background(), multiWrapValue)
require.NoError(t, err)

// send the answer
_, err = client.Logical().Write("sys/storage/raft/bootstrap/answer", map[string]interface{}{
"answer": base64.StdEncoding.EncodeToString(plaintext),
"server_id": "node-0",
"cluster_addr": "127.0.0.1:8200",
"sdk_version": "1.1.1",
"upgrade_version": "1.2.3",
"non_voter": false,
})

require.ErrorContains(t, err, "no expected answer for the server id provided")
}

// TestRaft_ChallengeSpam creates 40 raft bootstrap challenges. The first 20
// should succeed. After 20 challenges have been created, slow down the requests
// so that there are 2.5 occurring per second. Some of these will fail, due to
// rate limiting, but others will succeed.
func TestRaft_ChallengeSpam(t *testing.T) {
t.Parallel()
cluster, _ := raftCluster(t, &RaftClusterOpts{
DisableFollowerJoins: true,
})
defer cluster.Cleanup()

// Execute 2 * MaxInFlightRequests, over a period that should allow some to proceed as the token bucket
// refills.
var someLaterFailed bool
var someLaterSucceeded bool
for n := 0; n < 2*vault.RaftInitialChallengeLimit; n++ {
_, err := cluster.Cores[0].Client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
"server_id": fmt.Sprintf("core-%d", n),
})
// First MaxInFlightRequests should succeed for sure
if n < vault.RaftInitialChallengeLimit {
require.NoError(t, err)
} else {
// slow down to twice the configured rps
time.Sleep((1000 * time.Millisecond) / (2 * time.Duration(vault.RaftChallengesPerSecond)))
if err != nil {
require.Equal(t, 429, err.(*api.ResponseError).StatusCode)
someLaterFailed = true
} else {
someLaterSucceeded = true
}
}
}
require.True(t, someLaterFailed)
require.True(t, someLaterSucceeded)
}

func TestRaft_Join(t *testing.T) {
t.Parallel()
cluster, _ := raftCluster(t, &RaftClusterOpts{
Expand Down
23 changes: 13 additions & 10 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"github.com/hashicorp/vault/version"
"github.com/mitchellh/mapstructure"
"golang.org/x/crypto/sha3"
"golang.org/x/time/rate"
)

const (
Expand Down Expand Up @@ -94,11 +95,12 @@ func NewSystemBackend(core *Core, logger log.Logger, config *logical.BackendConf
}

b := &SystemBackend{
Core: core,
db: db,
logger: logger,
mfaBackend: NewPolicyMFABackend(core, logger),
syncBackend: syncBackend,
Core: core,
db: db,
logger: logger,
mfaBackend: NewPolicyMFABackend(core, logger),
syncBackend: syncBackend,
raftChallengeLimiter: rate.NewLimiter(rate.Limit(RaftChallengesPerSecond), RaftInitialChallengeLimit),
}

b.Backend = &framework.Backend{
Expand Down Expand Up @@ -270,11 +272,12 @@ func (b *SystemBackend) rawPaths() []*framework.Path {
type SystemBackend struct {
*framework.Backend
entSystemBackend
Core *Core
db *memdb.MemDB
logger log.Logger
mfaBackend *PolicyMFABackend
syncBackend *SecretsSyncBackend
Core *Core
db *memdb.MemDB
logger log.Logger
mfaBackend *PolicyMFABackend
syncBackend *SecretsSyncBackend
raftChallengeLimiter *rate.Limiter
}

// handleConfigStateSanitized returns the current configuration state. The configuration
Expand Down
68 changes: 48 additions & 20 deletions vault/logical_system_raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -272,6 +273,10 @@ func (b *SystemBackend) handleRaftRemovePeerUpdate() framework.OperationFunc {
}
}

const answerSize = 16

var answerMaxEncodedSize = base64.StdEncoding.EncodedLen(answerSize)

func (b *SystemBackend) handleRaftBootstrapChallengeWrite(makeSealer func() snapshot.Sealer) framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
serverID := d.Get("server_id").(string)
Expand All @@ -280,25 +285,42 @@ func (b *SystemBackend) handleRaftBootstrapChallengeWrite(makeSealer func() snap
}

var answer []byte
answerRaw, ok := b.Core.pendingRaftPeers.Load(serverID)
b.Core.pendingRaftPeersLock.RLock()
challenge, ok := b.Core.pendingRaftPeers.Get(serverID)
b.Core.pendingRaftPeersLock.RUnlock()
if !ok {
var err error
answer, err = uuid.GenerateRandomBytes(16)
if err != nil {
return nil, err
if !b.raftChallengeLimiter.Allow() {
return logical.RespondWithStatusCode(logical.ErrorResponse("too many raft challenges in flight"), req, http.StatusTooManyRequests)
}
b.Core.pendingRaftPeers.Store(serverID, answer)
} else {
answer = answerRaw.([]byte)
}

sealer := makeSealer()
if sealer == nil {
return nil, errors.New("core has no seal Access to write raft bootstrap challenge")
}
protoBlob, err := sealer.Seal(ctx, answer)
if err != nil {
return nil, err
b.Core.pendingRaftPeersLock.Lock()
defer b.Core.pendingRaftPeersLock.Unlock()

challenge, ok = b.Core.pendingRaftPeers.Get(serverID)
if !ok {

var err error
answer, err = uuid.GenerateRandomBytes(answerSize)
if err != nil {
return nil, err
}

sealer := makeSealer()
if sealer == nil {
return nil, errors.New("core has no seal access to write raft bootstrap challenge")
}
protoBlob, err := sealer.Seal(ctx, answer)
if err != nil {
return nil, err
}

challenge = &raftBootstrapChallenge{
serverID: serverID,
answer: answer,
challenge: protoBlob,
}
b.Core.pendingRaftPeers.Add(serverID, challenge)
}
}

sealConfig, err := b.Core.seal.BarrierConfig(ctx)
Expand All @@ -308,7 +330,7 @@ func (b *SystemBackend) handleRaftBootstrapChallengeWrite(makeSealer func() snap

return &logical.Response{
Data: map[string]interface{}{
"challenge": base64.StdEncoding.EncodeToString(protoBlob),
"challenge": base64.StdEncoding.EncodeToString(challenge.challenge),
"seal_config": sealConfig,
},
}, nil
Expand All @@ -330,6 +352,9 @@ func (b *SystemBackend) handleRaftBootstrapAnswerWrite() framework.OperationFunc
if len(answerRaw) == 0 {
return logical.ErrorResponse("no answer provided"), logical.ErrInvalidRequest
}
if len(answerRaw) > answerMaxEncodedSize {
return logical.ErrorResponse("answer is too long"), logical.ErrInvalidRequest
}
clusterAddr := d.Get("cluster_addr").(string)
if len(clusterAddr) == 0 {
return logical.ErrorResponse("no cluster_addr provided"), logical.ErrInvalidRequest
Expand All @@ -342,14 +367,17 @@ func (b *SystemBackend) handleRaftBootstrapAnswerWrite() framework.OperationFunc
return logical.ErrorResponse("could not base64 decode answer"), logical.ErrInvalidRequest
}

expectedAnswerRaw, ok := b.Core.pendingRaftPeers.Load(serverID)
b.Core.pendingRaftPeersLock.Lock()
expectedChallenge, ok := b.Core.pendingRaftPeers.Get(serverID)
if !ok {
b.Core.pendingRaftPeersLock.Unlock()
return logical.ErrorResponse("no expected answer for the server id provided"), logical.ErrInvalidRequest
}

b.Core.pendingRaftPeers.Delete(serverID)
b.Core.pendingRaftPeers.Remove(serverID)
b.Core.pendingRaftPeersLock.Unlock()

if subtle.ConstantTimeCompare(answer, expectedAnswerRaw.([]byte)) == 0 {
if subtle.ConstantTimeCompare(answer, expectedChallenge.answer) == 0 {
return logical.ErrorResponse("invalid answer given"), logical.ErrInvalidRequest
}

Expand Down
Loading

0 comments on commit d92fac4

Please sign in to comment.