Skip to content

Commit

Permalink
Merge pull request #57 from vearne/fix/wraparound
Browse files Browse the repository at this point in the history
fix seq wrap around
  • Loading branch information
vearne authored Nov 8, 2024
2 parents 88058c9 + dfd8468 commit b0af2e4
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 76 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
VERSION := v0.2.3
VERSION := v0.2.4

BIN_NAME = grpcr
CONTAINER = grpcr
Expand Down
9 changes: 5 additions & 4 deletions http2/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
"math"

"github.com/vearne/grpcreplay/protocol"
slog "github.com/vearne/simplelog"
Expand Down Expand Up @@ -49,23 +50,23 @@ func (p *Processor) ProcessTCPPkg() {
}
hc := p.ConnRepository[dc]

payloadSize := uint32(len(payload))

// SYN/ACK/FIN
if len(payload) <= 0 {
if pkg.TCP.FIN {
slog.Info("got Fin package, close connection:%v", dc.String())
hc.TCPBuffer.Close()
delete(p.ConnRepository, dc)
} else {
hc.TCPBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload))
hc.TCPBuffer.leftPointer = hc.TCPBuffer.expectedSeq
hc.TCPBuffer.expectedSeq = (pkg.TCP.Seq + payloadSize) % math.MaxUint32
}
continue
}

// connection preface
if IsConnPreface(payload) {
hc.TCPBuffer.expectedSeq = int64(pkg.TCP.Seq) + int64(len(pkg.TCP.Payload))
hc.TCPBuffer.leftPointer = hc.TCPBuffer.expectedSeq
hc.TCPBuffer.expectedSeq = (pkg.TCP.Seq + payloadSize) % math.MaxUint32
continue
}

Expand Down
88 changes: 26 additions & 62 deletions http2/tcp_buffer.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
package http2

import (
"bytes"
"github.com/google/gopacket/layers"
"github.com/huandu/skiplist"
slog "github.com/vearne/simplelog"
"math"
"net"
"sync/atomic"
)

type TCPBuffer struct {
//The number of bytes of data currently cached
size uint32
actualCanReadSize uint32
size atomic.Int64
actualCanReadSize atomic.Int64
List *skiplist.SkipList
expectedSeq int64
// The sliding window contains the leftPointer
leftPointer int64

expectedSeq uint32
//There is at most one reader to read
dataChannel chan []byte
closeChan chan struct{}
Expand All @@ -26,11 +23,10 @@ type TCPBuffer struct {
func NewTCPBuffer() *TCPBuffer {
var sb TCPBuffer
sb.List = skiplist.New(skiplist.Uint32)
sb.size = 0
sb.actualCanReadSize = 0
sb.expectedSeq = -1
sb.leftPointer = -1
sb.dataChannel = make(chan []byte, 10)
sb.size.Store(0)
sb.actualCanReadSize.Store(0)
sb.expectedSeq = 0
sb.dataChannel = make(chan []byte, 100)
sb.closeChan = make(chan struct{})
return &sb
}
Expand All @@ -47,72 +43,41 @@ func (sb *TCPBuffer) Read(p []byte) (n int, err error) {
err = net.ErrClosed
case data = <-sb.dataChannel:
n = copy(p, data)
dataSize := int64(len(data))
sb.size.Add(dataSize * -1)
sb.actualCanReadSize.Add(dataSize * -1)
}
slog.Debug("SocketBuffer.Read, got:%v bytes", n)
return n, err
}

func (sb *TCPBuffer) AddTCP(tcpPkg *layers.TCP) {
sb.addTCP(tcpPkg)

if sb.actualCanReadSize > 0 {
slog.Debug("SocketBuffer.AddTCP, satisfy the conditions, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)
data := sb.getData()
slog.Debug("push to channel: %v bytes", len(data))
sb.dataChannel <- data
}
}

func (sb *TCPBuffer) addTCP(tcpPkg *layers.TCP) {
slog.Debug("[start]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)
sb.size.Load(), sb.actualCanReadSize.Load(), sb.expectedSeq)

// duplicate package
if int64(tcpPkg.Seq) < sb.leftPointer || sb.List.Get(tcpPkg.Seq) != nil {
if sb.List.Get(tcpPkg.Seq) != nil {
slog.Debug("[end]SocketBuffer.addTCP-duplicate package, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)
sb.size.Load(), sb.actualCanReadSize.Load(), sb.expectedSeq)
return
}

ele := sb.List.Set(tcpPkg.Seq, tcpPkg)
sb.size += uint32(len(tcpPkg.Payload))
sb.size.Add(int64(len(tcpPkg.Payload)))
needRemoveList := make([]*skiplist.Element, 0)

for ele != nil && sb.expectedSeq == int64(tcpPkg.Seq) {
for ele != nil && sb.expectedSeq == tcpPkg.Seq {
// expect next sequence number
sb.expectedSeq = int64((tcpPkg.Seq + uint32(len(tcpPkg.Payload))) % math.MaxUint32)
sb.actualCanReadSize += uint32(len(tcpPkg.Payload))
// sequence numbers may wrap around
payloadSize := uint32(len(tcpPkg.Payload))
sb.actualCanReadSize.Add(int64(payloadSize))
sb.expectedSeq = (tcpPkg.Seq + payloadSize) % math.MaxUint32

ele = ele.Next()
if ele != nil {
tcpPkg = ele.Value.(*layers.TCP)
}
}
slog.Debug("[end]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)
}

func (sb *TCPBuffer) getData() []byte {
slog.Debug("[start]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size, sb.actualCanReadSize, sb.expectedSeq)

var tcpPkg *layers.TCP
buf := bytes.NewBuffer([]byte{})
ele := sb.List.Front()
if ele != nil {
tcpPkg = ele.Value.(*layers.TCP)
}

needRemoveList := make([]*skiplist.Element, 0)
for ele != nil && int64(tcpPkg.Seq) <= sb.expectedSeq {
sb.actualCanReadSize -= uint32(len(tcpPkg.Payload))
sb.size -= uint32(len(tcpPkg.Payload))
sb.leftPointer += int64(len(tcpPkg.Payload))

buf.Write(tcpPkg.Payload)
// push to channel
sb.dataChannel <- tcpPkg.Payload
needRemoveList = append(needRemoveList, ele)

ele = ele.Next()
ele = sb.List.Get(sb.expectedSeq)
if ele != nil {
tcpPkg = ele.Value.(*layers.TCP)
}
Expand All @@ -123,7 +88,6 @@ func (sb *TCPBuffer) getData() []byte {
sb.List.RemoveElement(element)
}

slog.Debug("[end]SocketBuffer.getData, size:%v, actualCanReadSize:%v, expectedSeq:%v, data: %v bytes",
sb.size, sb.actualCanReadSize, sb.expectedSeq, buf.Len())
return buf.Bytes()
slog.Debug("[end]SocketBuffer.addTCP, size:%v, actualCanReadSize:%v, expectedSeq:%v",
sb.size.Load(), sb.actualCanReadSize.Load(), sb.expectedSeq)
}
75 changes: 66 additions & 9 deletions http2/tcp_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ func TestSocketBufferSequence1(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 1000
buffer.leftPointer = 1000

var tcpPkgA layers.TCP
tcpPkgA.Seq = 1000
Expand Down Expand Up @@ -44,7 +43,6 @@ func TestSocketBufferSequence2(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 1000
buffer.leftPointer = 1000

var tcpPkgA layers.TCP
tcpPkgA.Seq = 1000
Expand Down Expand Up @@ -80,7 +78,6 @@ func TestSocketBufferSequence3(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 1000
buffer.leftPointer = 1000

var tcpPkgA layers.TCP
tcpPkgA.Seq = 1000
Expand Down Expand Up @@ -112,22 +109,21 @@ func TestSocketBufferSequence3(t *testing.T) {
assert.Nil(t, err)
}

func TestSocketBufferDuplicate(t *testing.T) {
func TestSocketBufferWrapAround1(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 1000
buffer.leftPointer = 1000
buffer.expectedSeq = 4294967290

var tcpPkgA layers.TCP
tcpPkgA.Seq = 1000
tcpPkgA.Seq = 4294967290
tcpPkgA.Payload = []byte("aaaaaaaaaa")

var tcpPkgB layers.TCP
tcpPkgB.Seq = 1010
tcpPkgB.Seq = 4
tcpPkgB.Payload = []byte("bbbbbbbbbb")

var tcpPkgC layers.TCP
tcpPkgC.Seq = 1020
tcpPkgC.Seq = 14
tcpPkgC.Payload = []byte("cccccccccc")

buffer.AddTCP(&tcpPkgA)
Expand All @@ -143,3 +139,64 @@ func TestSocketBufferDuplicate(t *testing.T) {
// assert for nil (good for errors)
assert.Nil(t, err)
}

func TestSocketBufferWrapAround2(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 4294967290

var tcpPkgA layers.TCP
tcpPkgA.Seq = 4294967290
tcpPkgA.Payload = []byte("aaaaaaaaaa")

var tcpPkgB layers.TCP
tcpPkgB.Seq = 4
tcpPkgB.Payload = []byte("bbbbbbbbbb")

var tcpPkgC layers.TCP
tcpPkgC.Seq = 14
tcpPkgC.Payload = []byte("cccccccccc")

buffer.AddTCP(&tcpPkgB)
buffer.AddTCP(&tcpPkgA)
buffer.AddTCP(&tcpPkgC)
buffer.AddTCP(&tcpPkgA)

buf := make([]byte, 1024)
n, err := io.ReadAtLeast(buffer, buf, 30)
// assert equality
assert.Equal(t, 30, n, "read data")
assert.Equal(t, "aaaaaaaaaabbbbbbbbbbcccccccccc", string(buf[0:n]), "read data")
// assert for nil (good for errors)
assert.Nil(t, err)
}

func TestSocketBufferWrapAround3(t *testing.T) {
slog.SetLevel(slog.DebugLevel)
buffer := NewTCPBuffer()
buffer.expectedSeq = 4294967290

var tcpPkgA layers.TCP
tcpPkgA.Seq = 4294967290
tcpPkgA.Payload = []byte("aaaaaaaaaa")

var tcpPkgB layers.TCP
tcpPkgB.Seq = 4
tcpPkgB.Payload = []byte("bbbbbbbbbb")

var tcpPkgC layers.TCP
tcpPkgC.Seq = 14
tcpPkgC.Payload = []byte("cccccccccc")

buffer.AddTCP(&tcpPkgA)
buffer.AddTCP(&tcpPkgB)
buffer.AddTCP(&tcpPkgC)

buf := make([]byte, 1024)
n, err := io.ReadAtLeast(buffer, buf, 30)
// assert equality
assert.Equal(t, 30, n, "read data")
assert.Equal(t, "aaaaaaaaaabbbbbbbbbbcccccccccc", string(buf[0:n]), "read data")
// assert for nil (good for errors)
assert.Nil(t, err)
}

0 comments on commit b0af2e4

Please sign in to comment.