Backport of VAULT-31264: Limit raft joins into release/1.18.x #28791

Merged
5 changes: 4 additions & 1 deletion vault/core.go
@@ -40,6 +40,7 @@ import (
     "github.com/hashicorp/go-secure-stdlib/strutil"
     "github.com/hashicorp/go-secure-stdlib/tlsutil"
     "github.com/hashicorp/go-uuid"
+    lru "github.com/hashicorp/golang-lru/v2"
     kv "github.com/hashicorp/vault-plugin-secrets-kv"
     "github.com/hashicorp/vault/api"
     "github.com/hashicorp/vault/audit"
@@ -628,7 +629,9 @@ type Core struct {
     // Stop channel for raft TLS rotations
     raftTLSRotationStopCh chan struct{}
     // Stores the pending peers we are waiting to give answers
-    pendingRaftPeers *sync.Map
+    pendingRaftPeers *lru.Cache[string, *raftBootstrapChallenge]
+    // holds the lock for modifying pendingRaftPeers
+    pendingRaftPeersLock sync.RWMutex
 
     // rawConfig stores the config as-is from the provided server configuration.
     rawConfig *atomic.Value
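The change above replaces the unbounded sync.Map of pending bootstrap peers with a size-bounded LRU cache from hashicorp/golang-lru/v2, guarded by an RWMutex because the lookup-then-add sequence now spans several cache calls. The sketch below is not part of the PR; it only illustrates the eviction behaviour that TestRaftChallenge_evicted (further down) relies on. The struct fields are copied from the diff; the size constant and everything else are illustrative.

package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

// Shape of the cached value as it appears in the diff; only these
// three fields are shown there.
type raftBootstrapChallenge struct {
	serverID  string
	answer    []byte
	challenge []byte
}

func main() {
	// Illustrative bound; the real one comes from vault.RaftInitialChallengeLimit.
	const pendingLimit = 20

	// lru.New returns a generic, fixed-size cache that silently evicts the
	// least-recently-used entry once the bound is exceeded.
	pending, err := lru.New[string, *raftBootstrapChallenge](pendingLimit)
	if err != nil {
		panic(err)
	}

	// Add one more entry than the cache can hold.
	for i := 0; i <= pendingLimit; i++ {
		id := fmt.Sprintf("node-%d", i)
		pending.Add(id, &raftBootstrapChallenge{serverID: id})
	}

	// node-0 was the least recently used entry, so the 21st Add evicted it;
	// this is why a stale answer for node-0 is rejected in the tests.
	_, ok := pending.Get("node-0")
	fmt.Println("node-0 still pending:", ok, "entries:", pending.Len()) // false 20
}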
166 changes: 166 additions & 0 deletions vault/external_tests/raft/raft_test.go
@@ -7,6 +7,7 @@ import (
     "bytes"
     "context"
     "crypto/md5"
+    "encoding/base64"
     "errors"
     "fmt"
     "io"
@@ -16,7 +17,9 @@ import (
     "testing"
     "time"
 
+    "github.com/golang/protobuf/proto"
     "github.com/hashicorp/go-cleanhttp"
+    wrapping "github.com/hashicorp/go-kms-wrapping/v2"
     "github.com/hashicorp/go-uuid"
     "github.com/hashicorp/vault/api"
     credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
@@ -248,6 +251,169 @@ func TestRaft_Retry_Join(t *testing.T) {
     })
 }
 
+// TestRaftChallenge_sameAnswerSameID_concurrent verifies that 15 goroutines
+// all requesting a raft challenge with the same ID all return the same answer.
+// This is a regression test for a TOCTTOU race found during testing.
+func TestRaftChallenge_sameAnswerSameID_concurrent(t *testing.T) {
+    t.Parallel()
+
+    cluster, _ := raftCluster(t, &RaftClusterOpts{
+        DisableFollowerJoins: true,
+        NumCores:             1,
+    })
+    defer cluster.Cleanup()
+    client := cluster.Cores[0].Client
+
+    challenges := make(chan string, 15)
+    wg := sync.WaitGroup{}
+    for i := 0; i < 15; i++ {
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            res, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
+                "server_id": "node1",
+            })
+            require.NoError(t, err)
+            challenges <- res.Data["challenge"].(string)
+        }()
+    }
+
+    wg.Wait()
+    challengeSet := make(map[string]struct{})
+    close(challenges)
+    for challenge := range challenges {
+        challengeSet[challenge] = struct{}{}
+    }
+
+    require.Len(t, challengeSet, 1)
+}
+
+// TestRaftChallenge_sameAnswerSameID verifies that repeated bootstrap requests
+// with the same node ID return the same challenge, but that a different node ID
+// returns a different challenge.
+func TestRaftChallenge_sameAnswerSameID(t *testing.T) {
+    t.Parallel()
+
+    cluster, _ := raftCluster(t, &RaftClusterOpts{
+        DisableFollowerJoins: true,
+        NumCores:             1,
+    })
+    defer cluster.Cleanup()
+    client := cluster.Cores[0].Client
+    res, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
+        "server_id": "node1",
+    })
+    require.NoError(t, err)
+
+    // querying the same ID returns the same challenge
+    challenge := res.Data["challenge"]
+    resSameID, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
+        "server_id": "node1",
+    })
+    require.NoError(t, err)
+    require.Equal(t, challenge, resSameID.Data["challenge"])
+
+    // querying a different ID returns a new challenge
+    resDiffID, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
+        "server_id": "node2",
+    })
+    require.NoError(t, err)
+    require.NotEqual(t, challenge, resDiffID.Data["challenge"])
+}
+
+// TestRaftChallenge_evicted verifies that a valid answer errors if there have
+// been more than 20 challenge requests after it, because our cache of pending
+// bootstraps is limited to 20.
+func TestRaftChallenge_evicted(t *testing.T) {
+    t.Parallel()
+    cluster, _ := raftCluster(t, &RaftClusterOpts{
+        DisableFollowerJoins: true,
+        NumCores:             1,
+    })
+    defer cluster.Cleanup()
+    firstResponse := map[string]interface{}{}
+    client := cluster.Cores[0].Client
+    for i := 0; i < vault.RaftInitialChallengeLimit+1; i++ {
+        if i == vault.RaftInitialChallengeLimit {
+            // wait before sending the last request, so we don't get rate
+            // limited
+            time.Sleep(2 * time.Second)
+        }
+        res, err := client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
+            "server_id": fmt.Sprintf("node-%d", i),
+        })
+        require.NoError(t, err)
+
+        // save the response from the first challenge
+        if i == 0 {
+            firstResponse = res.Data
+        }
+    }
+
+    // get the answer to the challenge
+    challengeRaw, err := base64.StdEncoding.DecodeString(firstResponse["challenge"].(string))
+    require.NoError(t, err)
+    eBlob := &wrapping.BlobInfo{}
+    err = proto.Unmarshal(challengeRaw, eBlob)
+    require.NoError(t, err)
+    access := cluster.Cores[0].SealAccess().GetAccess()
+    multiWrapValue := &vaultseal.MultiWrapValue{
+        Generation: access.Generation(),
+        Slots:      []*wrapping.BlobInfo{eBlob},
+    }
+    plaintext, _, err := access.Decrypt(context.Background(), multiWrapValue)
+    require.NoError(t, err)
+
+    // send the answer
+    _, err = client.Logical().Write("sys/storage/raft/bootstrap/answer", map[string]interface{}{
+        "answer":          base64.StdEncoding.EncodeToString(plaintext),
+        "server_id":       "node-0",
+        "cluster_addr":    "127.0.0.1:8200",
+        "sdk_version":     "1.1.1",
+        "upgrade_version": "1.2.3",
+        "non_voter":       false,
+    })
+
+    require.ErrorContains(t, err, "no expected answer for the server id provided")
+}
+
+// TestRaft_ChallengeSpam creates 40 raft bootstrap challenges. The first 20
+// should succeed. After 20 challenges have been created, slow down the requests
+// so that there are 2.5 occurring per second. Some of these will fail, due to
+// rate limiting, but others will succeed.
+func TestRaft_ChallengeSpam(t *testing.T) {
+    t.Parallel()
+    cluster, _ := raftCluster(t, &RaftClusterOpts{
+        DisableFollowerJoins: true,
+    })
+    defer cluster.Cleanup()
+
+    // Execute 2 * RaftInitialChallengeLimit requests, over a period that
+    // should allow some to proceed as the token bucket refills.
+    var someLaterFailed bool
+    var someLaterSucceeded bool
+    for n := 0; n < 2*vault.RaftInitialChallengeLimit; n++ {
+        _, err := cluster.Cores[0].Client.Logical().Write("sys/storage/raft/bootstrap/challenge", map[string]interface{}{
+            "server_id": fmt.Sprintf("core-%d", n),
+        })
+        // First RaftInitialChallengeLimit requests should succeed for sure
+        if n < vault.RaftInitialChallengeLimit {
+            require.NoError(t, err)
+        } else {
+            // slow down to twice the configured rps
+            time.Sleep((1000 * time.Millisecond) / (2 * time.Duration(vault.RaftChallengesPerSecond)))
+            if err != nil {
+                require.Equal(t, 429, err.(*api.ResponseError).StatusCode)
+                someLaterFailed = true
+            } else {
+                someLaterSucceeded = true
+            }
+        }
+    }
+    require.True(t, someLaterFailed)
+    require.True(t, someLaterSucceeded)
+}
+
 func TestRaft_Join(t *testing.T) {
     t.Parallel()
     cluster, _ := raftCluster(t, &RaftClusterOpts{
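TestRaft_ChallengeSpam relies on the token-bucket semantics of golang.org/x/time/rate, which the next file wires into SystemBackend as raftChallengeLimiter: the bucket starts with a burst of RaftInitialChallengeLimit tokens and refills at RaftChallengesPerSecond, so once the burst is spent, issuing requests at twice the refill rate lets roughly every other Allow() call through. Below is a standalone sketch of that behaviour; the constants are illustrative, not Vault's.

package main

import (
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// Illustrative values; Vault's are RaftChallengesPerSecond and
	// RaftInitialChallengeLimit.
	const perSecond = 2.5 // refill rate, tokens per second
	const burst = 20      // bucket size: challenges allowed back to back

	limiter := rate.NewLimiter(rate.Limit(perSecond), burst)

	// The first `burst` calls drain the bucket immediately and all succeed.
	allowedBurst := 0
	for i := 0; i < burst; i++ {
		if limiter.Allow() {
			allowedBurst++
		}
	}

	// Now request at twice the refill rate: only about half the calls
	// find a token, the rest are the 429s the test expects.
	interval := time.Duration(float64(time.Second) / (2 * perSecond))
	allowed, denied := 0, 0
	for i := 0; i < 20; i++ {
		time.Sleep(interval)
		if limiter.Allow() {
			allowed++
		} else {
			denied++
		}
	}
	fmt.Println("burst allowed:", allowedBurst)
	fmt.Println("steady state allowed:", allowed, "denied:", denied)
}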
23 changes: 13 additions & 10 deletions vault/logical_system.go
@@ -56,6 +56,7 @@ import (
     "github.com/hashicorp/vault/version"
     "github.com/mitchellh/mapstructure"
     "golang.org/x/crypto/sha3"
+    "golang.org/x/time/rate"
 )
 
 const (
@@ -94,11 +95,12 @@ func NewSystemBackend(core *Core, logger log.Logger, config *logical.BackendConf
     }
 
     b := &SystemBackend{
-        Core:        core,
-        db:          db,
-        logger:      logger,
-        mfaBackend:  NewPolicyMFABackend(core, logger),
-        syncBackend: syncBackend,
+        Core:                 core,
+        db:                   db,
+        logger:               logger,
+        mfaBackend:           NewPolicyMFABackend(core, logger),
+        syncBackend:          syncBackend,
+        raftChallengeLimiter: rate.NewLimiter(rate.Limit(RaftChallengesPerSecond), RaftInitialChallengeLimit),
     }
 
     b.Backend = &framework.Backend{
@@ -270,11 +272,12 @@ func (b *SystemBackend) rawPaths() []*framework.Path {
 type SystemBackend struct {
     *framework.Backend
     entSystemBackend
-    Core        *Core
-    db          *memdb.MemDB
-    logger      log.Logger
-    mfaBackend  *PolicyMFABackend
-    syncBackend *SecretsSyncBackend
+    Core                 *Core
+    db                   *memdb.MemDB
+    logger               log.Logger
+    mfaBackend           *PolicyMFABackend
+    syncBackend          *SecretsSyncBackend
+    raftChallengeLimiter *rate.Limiter
 }
 
 // handleConfigStateSanitized returns the current configuration state. The configuration
68 changes: 48 additions & 20 deletions vault/logical_system_raft.go
@@ -9,6 +9,7 @@ import (
     "encoding/base64"
     "errors"
     "fmt"
+    "net/http"
     "strings"
     "time"
 
@@ -272,6 +273,10 @@ func (b *SystemBackend) handleRaftRemovePeerUpdate() framework.OperationFunc {
     }
 }
 
+const answerSize = 16
+
+var answerMaxEncodedSize = base64.StdEncoding.EncodedLen(answerSize)
+
 func (b *SystemBackend) handleRaftBootstrapChallengeWrite(makeSealer func() snapshot.Sealer) framework.OperationFunc {
     return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
         serverID := d.Get("server_id").(string)
@@ -280,25 +285,42 @@
         }
 
         var answer []byte
-        answerRaw, ok := b.Core.pendingRaftPeers.Load(serverID)
+        b.Core.pendingRaftPeersLock.RLock()
+        challenge, ok := b.Core.pendingRaftPeers.Get(serverID)
+        b.Core.pendingRaftPeersLock.RUnlock()
         if !ok {
-            var err error
-            answer, err = uuid.GenerateRandomBytes(16)
-            if err != nil {
-                return nil, err
+            if !b.raftChallengeLimiter.Allow() {
+                return logical.RespondWithStatusCode(logical.ErrorResponse("too many raft challenges in flight"), req, http.StatusTooManyRequests)
             }
-            b.Core.pendingRaftPeers.Store(serverID, answer)
-        } else {
-            answer = answerRaw.([]byte)
-        }
-
-        sealer := makeSealer()
-        if sealer == nil {
-            return nil, errors.New("core has no seal Access to write raft bootstrap challenge")
-        }
-        protoBlob, err := sealer.Seal(ctx, answer)
-        if err != nil {
-            return nil, err
+            b.Core.pendingRaftPeersLock.Lock()
+            defer b.Core.pendingRaftPeersLock.Unlock()
+
+            challenge, ok = b.Core.pendingRaftPeers.Get(serverID)
+            if !ok {
+
+                var err error
+                answer, err = uuid.GenerateRandomBytes(answerSize)
+                if err != nil {
+                    return nil, err
+                }
+
+                sealer := makeSealer()
+                if sealer == nil {
+                    return nil, errors.New("core has no seal access to write raft bootstrap challenge")
+                }
+                protoBlob, err := sealer.Seal(ctx, answer)
+                if err != nil {
+                    return nil, err
+                }
+
+                challenge = &raftBootstrapChallenge{
+                    serverID:  serverID,
+                    answer:    answer,
+                    challenge: protoBlob,
+                }
+                b.Core.pendingRaftPeers.Add(serverID, challenge)
+            }
         }
 
         sealConfig, err := b.Core.seal.BarrierConfig(ctx)
@@ -308,7 +330,7 @@
 
         return &logical.Response{
             Data: map[string]interface{}{
-                "challenge":   base64.StdEncoding.EncodeToString(protoBlob),
+                "challenge":   base64.StdEncoding.EncodeToString(challenge.challenge),
                 "seal_config": sealConfig,
            },
        }, nil
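The rewritten challenge handler above is the TOCTTOU fix that TestRaftChallenge_sameAnswerSameID_concurrent exercises: check the cache under a read lock, and only on a miss take the write lock, check again, then generate and store the challenge, so concurrent requests for one server_id converge on a single answer. A minimal sketch of that check, lock, re-check pattern follows; the types and names are hypothetical, not Vault's.

package main

import (
	"fmt"
	"sync"
)

// challengeStore caches one challenge per server ID using the
// check / lock / re-check pattern from the hunk above.
type challengeStore struct {
	mu      sync.RWMutex
	pending map[string]string // serverID -> challenge
}

func (s *challengeStore) getOrCreate(serverID string, generate func() string) string {
	// Fast path: read lock only.
	s.mu.RLock()
	c, ok := s.pending[serverID]
	s.mu.RUnlock()
	if ok {
		return c
	}

	// Slow path: take the write lock and re-check, because another
	// goroutine may have stored an entry between the two locks.
	s.mu.Lock()
	defer s.mu.Unlock()
	if c, ok := s.pending[serverID]; ok {
		return c
	}
	c = generate()
	s.pending[serverID] = c
	return c
}

func main() {
	store := &challengeStore{pending: map[string]string{}}
	var wg sync.WaitGroup
	results := make(chan string, 15)

	n := 0
	for i := 0; i < 15; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			results <- store.getOrCreate("node1", func() string {
				n++ // only ever runs once, under the write lock
				return fmt.Sprintf("challenge-%d", n)
			})
		}()
	}
	wg.Wait()
	close(results)

	unique := map[string]struct{}{}
	for r := range results {
		unique[r] = struct{}{}
	}
	fmt.Println("distinct challenges handed out:", len(unique)) // 1
}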
@@ -330,6 +352,9 @@ func (b *SystemBackend) handleRaftBootstrapAnswerWrite() framework.OperationFunc
         if len(answerRaw) == 0 {
             return logical.ErrorResponse("no answer provided"), logical.ErrInvalidRequest
         }
+        if len(answerRaw) > answerMaxEncodedSize {
+            return logical.ErrorResponse("answer is too long"), logical.ErrInvalidRequest
+        }
         clusterAddr := d.Get("cluster_addr").(string)
         if len(clusterAddr) == 0 {
             return logical.ErrorResponse("no cluster_addr provided"), logical.ErrInvalidRequest
@@ -342,14 +367,17 @@
             return logical.ErrorResponse("could not base64 decode answer"), logical.ErrInvalidRequest
         }
 
-        expectedAnswerRaw, ok := b.Core.pendingRaftPeers.Load(serverID)
+        b.Core.pendingRaftPeersLock.Lock()
+        expectedChallenge, ok := b.Core.pendingRaftPeers.Get(serverID)
         if !ok {
+            b.Core.pendingRaftPeersLock.Unlock()
             return logical.ErrorResponse("no expected answer for the server id provided"), logical.ErrInvalidRequest
         }
 
-        b.Core.pendingRaftPeers.Delete(serverID)
+        b.Core.pendingRaftPeers.Remove(serverID)
+        b.Core.pendingRaftPeersLock.Unlock()
 
-        if subtle.ConstantTimeCompare(answer, expectedAnswerRaw.([]byte)) == 0 {
+        if subtle.ConstantTimeCompare(answer, expectedChallenge.answer) == 0 {
             return logical.ErrorResponse("invalid answer given"), logical.ErrInvalidRequest
         }
 
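The new answer-length guard above works because a well-formed answer is exactly answerSize (16) random bytes, so its base64 form can never exceed base64.StdEncoding.EncodedLen(16), which is 24 bytes; anything longer is rejected before the decode is attempted. A small sketch of that arithmetic, using a hypothetical helper rather than the PR's handler:

package main

import (
	"encoding/base64"
	"errors"
	"fmt"
)

const answerSize = 16 // raw answer bytes, as in the PR

// Longest base64 form a well-formed answer can take: ceil(16/3)*4 = 24 bytes.
var answerMaxEncodedSize = base64.StdEncoding.EncodedLen(answerSize)

// decodeAnswer is a hypothetical helper mirroring the order of checks in the
// handler: reject oversized input before paying for the base64 decode.
func decodeAnswer(answerRaw string) ([]byte, error) {
	if len(answerRaw) == 0 {
		return nil, errors.New("no answer provided")
	}
	if len(answerRaw) > answerMaxEncodedSize {
		return nil, errors.New("answer is too long")
	}
	return base64.StdEncoding.DecodeString(answerRaw)
}

func main() {
	fmt.Println("max encoded size:", answerMaxEncodedSize) // 24

	// A correctly sized answer (16 bytes -> 24 base64 characters) decodes fine.
	valid := base64.StdEncoding.EncodeToString(make([]byte, answerSize))
	if _, err := decodeAnswer(valid); err != nil {
		panic(err)
	}

	// A payload longer than 24 bytes is rejected before any decoding happens.
	_, err := decodeAnswer(valid + valid)
	fmt.Println(err) // answer is too long
}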