Skip to content

Commit

Permalink
Ensure installSnapshot consume stream. fixes issue hashicorp#212
Browse files Browse the repository at this point in the history
  • Loading branch information
superfell committed Jun 7, 2017
1 parent 939ebd2 commit ec99ca3
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ test:
go test -timeout=60s ./...

integ: test
INTEG_TESTS=yes go test -timeout=5s -run=Integ ./...
INTEG_TESTS=yes go test -timeout=25s -run=Integ ./...

deps:
go get -d -v ./...
Expand Down
127 changes: 98 additions & 29 deletions integ_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,34 @@ type RaftEnv struct {
}

func (r *RaftEnv) Release() {
r.logger.Printf("[WARN] Release node at %v", r.raft.localAddr)
r.Shutdown()
os.RemoveAll(r.dir)
}

// Shutdown shuts down raft & transport, but keeps track of its data, its restartable
// after a Shutdown() by calling Start()
func (r *RaftEnv) Shutdown() {
r.logger.Printf("[WARN] Shutdown node at %v", r.raft.localAddr)
f := r.raft.Shutdown()
if err := f.Error(); err != nil {
panic(err)
}
r.trans.Close()
os.RemoveAll(r.dir)
}

// Restart will start a raft node that was previously Shutdown()
func (r *RaftEnv) Restart(t *testing.T) {
trans, err := NewTCPTransport(string(r.raft.localAddr), nil, 2, time.Second, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
r.trans = trans
r.logger.Printf("[INFO] Starting node at %v", trans.LocalAddr())
raft, err := NewRaft(r.conf, r.fsm, r.store, r.store, r.snapshot, r.trans)
if err != nil {
t.Fatalf("err: %v", err)
}
r.raft = raft
}

func MakeRaft(t *testing.T, conf *Config, bootstrap bool) *RaftEnv {
Expand Down Expand Up @@ -69,11 +90,11 @@ func MakeRaft(t *testing.T, conf *Config, bootstrap bool) *RaftEnv {
fsm: &MockFSM{},
logger: log.New(&testLoggerAdapter{t: t}, "", log.Lmicroseconds),
}

trans, err := NewTCPTransport("127.0.0.1:0", nil, 2, time.Second, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
env.logger = log.New(os.Stdout, string(trans.LocalAddr())+" :", log.Lmicroseconds)
env.trans = trans

if bootstrap {
Expand All @@ -90,6 +111,7 @@ func MakeRaft(t *testing.T, conf *Config, bootstrap bool) *RaftEnv {
}

log.Printf("[INFO] Starting node at %v", trans.LocalAddr())
conf.Logger = env.logger
raft, err := NewRaft(conf, env.fsm, stable, stable, snap, trans)
if err != nil {
t.Fatalf("err: %v", err)
Expand Down Expand Up @@ -144,33 +166,51 @@ func NoErr(err error, t *testing.T) {
func CheckConsistent(envs []*RaftEnv, t *testing.T) {
limit := time.Now().Add(400 * time.Millisecond)
first := envs[0]
first.fsm.Lock()
defer first.fsm.Unlock()
var err error
CHECK:
l1 := len(first.fsm.logs)
for i := 1; i < len(envs); i++ {
env := envs[i]
env.fsm.Lock()
l2 := len(env.fsm.logs)
if l1 != l2 {
err = fmt.Errorf("log length mismatch %d %d", l1, l2)
env.fsm.Unlock()
goto ERR
}
for idx, log := range first.fsm.logs {
other := env.fsm.logs[idx]
if bytes.Compare(log, other) != 0 {
err = fmt.Errorf("log %d mismatch %v %v", idx, log, other)
err = fmt.Errorf("log entry %d mismatch between %s/%s : '%s' / '%s'", idx, first.raft.localAddr, env.raft.localAddr, log, other)
env.fsm.Unlock()
goto ERR
}
}
env.fsm.Unlock()
}
return
ERR:
if time.Now().After(limit) {
t.Fatalf("%v", err)
}
first.fsm.Unlock()
time.Sleep(20 * time.Millisecond)
first.fsm.Lock()
goto CHECK
}

// return a log entry that's at least sz long that has the prefix 'test i '
func logBytes(i, sz int) []byte {
var logBuffer bytes.Buffer
fmt.Fprintf(&logBuffer, "test %d ", i)
for logBuffer.Len() < sz {
logBuffer.WriteByte('x')
}
return logBuffer.Bytes()
}

// Tests Raft by creating a cluster, growing it to 5 nodes while
// causing various stressful conditions
func TestRaft_Integ(t *testing.T) {
Expand All @@ -188,15 +228,21 @@ func TestRaft_Integ(t *testing.T) {
env1 := MakeRaft(t, conf, true)
NoErr(WaitFor(env1, Leader), t)

// Do some commits
var futures []Future
for i := 0; i < 100; i++ {
futures = append(futures, env1.raft.Apply([]byte(fmt.Sprintf("test%d", i)), 0))
}
for _, f := range futures {
NoErr(WaitFuture(f, t), t)
env1.logger.Printf("[DEBUG] Applied %v", f)
totalApplied := 0
applyAndWait := func(leader *RaftEnv, n, sz int) {
// Do some commits
var futures []ApplyFuture
for i := 0; i < n; i++ {
futures = append(futures, leader.raft.Apply(logBytes(i, sz), 0))
}
for _, f := range futures {
NoErr(WaitFuture(f, t), t)
leader.logger.Printf("[DEBUG] Applied at %d, size %d", f.Index(), sz)
}
totalApplied += n
}
// Do some commits
applyAndWait(env1, 100, 10)

// Do a snapshot
NoErr(WaitFuture(env1.raft.Snapshot(), t), t)
Expand All @@ -216,15 +262,41 @@ func TestRaft_Integ(t *testing.T) {
NoErr(err, t)

// Do some more commits
futures = nil
for i := 0; i < 100; i++ {
futures = append(futures, leader.raft.Apply([]byte(fmt.Sprintf("test%d", i)), 0))
}
for _, f := range futures {
NoErr(WaitFuture(f, t), t)
leader.logger.Printf("[DEBUG] Applied %v", f)
applyAndWait(leader, 100, 10)

// Snapshot the leader
NoErr(WaitFuture(leader.raft.Snapshot(), t), t)
CheckConsistent(append([]*RaftEnv{env1}, envs...), t)

// shutdown a follower
disconnected := envs[len(envs)-1]
disconnected.Shutdown()

// Do some more commits [make sure the resulting snapshot will be a reasonable size]
applyAndWait(leader, 100, 10000)

// snapshot the leader [leaders log should be compacted past the disconnected follower log now]
NoErr(WaitFuture(leader.raft.Snapshot(), t), t)

// Unfortuantly we need to wait for the leader to start backing off RPCs to the down follower
// such that when the follower comes back up it'll run an election before it gets an rpc from
// the leader
time.Sleep(time.Second * 5)

// start the now out of date follower back up
disconnected.Restart(t)

// wait for it to get caught up
timeout := time.Now().Add(time.Second * 10)
for disconnected.raft.getLastApplied() < leader.raft.getLastApplied() {
time.Sleep(time.Millisecond)
if time.Now().After(timeout) {
t.Fatalf("Gave up waiting for follower to get caught up to leader")
}
}

CheckConsistent(append([]*RaftEnv{env1}, envs...), t)

// Shoot two nodes in the head!
rm1, rm2 := envs[0], envs[1]
rm1.Release()
Expand All @@ -237,14 +309,7 @@ func TestRaft_Integ(t *testing.T) {
NoErr(err, t)

// Do some more commits
futures = nil
for i := 0; i < 100; i++ {
futures = append(futures, leader.raft.Apply([]byte(fmt.Sprintf("test%d", i)), 0))
}
for _, f := range futures {
NoErr(WaitFuture(f, t), t)
leader.logger.Printf("[DEBUG] Applied %v", f)
}
applyAndWait(leader, 100, 10)

// Join a few new nodes!
for i := 0; i < 2; i++ {
Expand All @@ -255,6 +320,10 @@ func TestRaft_Integ(t *testing.T) {
envs = append(envs, env)
}

// Wait for a leader
leader, err = WaitForAny(Leader, append([]*RaftEnv{env1}, envs...))
NoErr(err, t)

// Remove the old nodes
NoErr(WaitFuture(leader.raft.RemoveServer(rm1.raft.localID, 0, 0), t), t)
NoErr(WaitFuture(leader.raft.RemoveServer(rm2.raft.localID, 0, 0), t), t)
Expand All @@ -270,8 +339,8 @@ func TestRaft_Integ(t *testing.T) {
allEnvs := append([]*RaftEnv{env1}, envs...)
CheckConsistent(allEnvs, t)

if len(env1.fsm.logs) != 300 {
t.Fatalf("should apply 300 logs! %d", len(env1.fsm.logs))
if len(env1.fsm.logs) != totalApplied {
t.Fatalf("should apply %d logs! %d", totalApplied, len(env1.fsm.logs))
}

for _, e := range envs {
Expand Down
3 changes: 3 additions & 0 deletions raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"container/list"
"fmt"
"io"
"io/ioutil"
"time"

"github.com/armon/go-metrics"
Expand Down Expand Up @@ -1238,6 +1239,7 @@ func (r *Raft) installSnapshot(rpc RPC, req *InstallSnapshotRequest) {
}
var rpcErr error
defer func() {
io.Copy(ioutil.Discard, rpc.Reader) // ensure we always consume all the snapshot data from the stream [see issue #212]
rpc.Respond(resp, rpcErr)
}()

Expand All @@ -1250,6 +1252,7 @@ func (r *Raft) installSnapshot(rpc RPC, req *InstallSnapshotRequest) {

// Ignore an older term
if req.Term < r.getCurrentTerm() {
r.logger.Printf("[INFO] raft: Ignoring installSnapshot request with older term of %d vs currentTerm %d", req.Term, r.getCurrentTerm())
return
}

Expand Down

0 comments on commit ec99ca3

Please sign in to comment.