From 35a5529f44f8e3938209e07404eccde0b88ea87c Mon Sep 17 00:00:00 2001 From: Ain Ghazal <99027643+ainghazal@users.noreply.github.com> Date: Thu, 8 Feb 2024 18:45:20 +0100 Subject: [PATCH] tests: add tests for reliable service under packet loss (#63) Exercise reliable service passing a vector of packet ID to simulate packet loss. Additionally, this PR updates github actions version to the latest action release for checkout (v4) and setup-go (v5). Eventually, we'll currently replace the test setup with the one we're using to run tests on the internal packages. Co-authored-by: Simone Basso --- .github/workflows/build-refactor.yml | 23 +- Makefile | 9 +- internal/reliabletransport/packets.go | 4 +- internal/reliabletransport/receiver.go | 2 +- internal/reliabletransport/receiver_test.go | 8 +- .../reliabletransport/reliable_ack_test.go | 34 +- .../reliabletransport/reliable_loss_test.go | 273 +++++++++++++ .../reliable_reorder_test.go | 15 +- internal/reliabletransport/sender_test.go | 8 +- internal/reliabletransport/service_test.go | 2 +- internal/vpntest/doc.go | 2 + internal/vpntest/packetio.go | 308 +++++++++++++- internal/vpntest/packetio_test.go | 376 ++++++++++++++++++ internal/vpntest/vpntest.go | 6 +- 14 files changed, 1007 insertions(+), 63 deletions(-) create mode 100644 internal/reliabletransport/reliable_loss_test.go create mode 100644 internal/vpntest/doc.go diff --git a/.github/workflows/build-refactor.yml b/.github/workflows/build-refactor.yml index a822bf03..c3751ee9 100644 --- a/.github/workflows/build-refactor.yml +++ b/.github/workflows/build-refactor.yml @@ -13,21 +13,30 @@ jobs: short-tests: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: setup go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: '1.21' - name: Run short tests run: go test --short -cover ./internal/... + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Lint with revive action, from pre-built image + uses: docker://morphy/revive-action:v2 + with: + path: "internal/..." + gosec: runs-on: ubuntu-latest env: GO111MODULE: on steps: - name: Checkout Source - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Run Gosec security scanner uses: securego/gosec@master with: @@ -36,9 +45,9 @@ jobs: coverage-threshold: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: setup go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: '1.21' - name: Ensure coverage threshold @@ -47,9 +56,9 @@ jobs: integration: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: setup go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: '1.21' - name: run integration tests diff --git a/Makefile b/Makefile index 881d2423..376d521f 100644 --- a/Makefile +++ b/Makefile @@ -99,16 +99,19 @@ netns-shell: sudo ip netns exec protected sudo -u `whoami` -i .PHONY: lint -lint: go-fmt go-vet go-sec +lint: go-fmt go-vet go-sec go-revive go-fmt: gofmt -s -l . go-vet: - go vet ./... + go vet internal/... go-sec: - gosec ./... + gosec internal/... + +go-revive: + revive internal/... clean: @rm -f coverage.out diff --git a/internal/reliabletransport/packets.go b/internal/reliabletransport/packets.go index 5fc34539..1cb3991a 100644 --- a/internal/reliabletransport/packets.go +++ b/internal/reliabletransport/packets.go @@ -33,11 +33,11 @@ func newInFlightPacket(p *model.Packet) *inFlightPacket { // ACKForHigherPacket increments the number of acks received for a higher pid than this packet. This will influence the fast rexmit selection algorithm. func (p *inFlightPacket) ACKForHigherPacket() { - p.higherACKs += 1 + p.higherACKs++ } func (p *inFlightPacket) ScheduleForRetransmission(t time.Time) { - p.retries += 1 + p.retries++ p.deadline = t.Add(p.backoff()) } diff --git a/internal/reliabletransport/receiver.go b/internal/reliabletransport/receiver.go index 4065f54b..9c2a927b 100644 --- a/internal/reliabletransport/receiver.go +++ b/internal/reliabletransport/receiver.go @@ -135,7 +135,7 @@ func (r *reliableReceiver) NextIncomingSequence() incomingSequence { for i, p := range r.incomingPackets { if p.ID-last == 1 { ready = append(ready, p) - last += 1 + last++ } else if p.ID > last { // here we broke sequentiality, but we want // to drop anything that is below lastConsumed diff --git a/internal/reliabletransport/receiver_test.go b/internal/reliabletransport/receiver_test.go index b5083201..0d9c49b4 100644 --- a/internal/reliabletransport/receiver_test.go +++ b/internal/reliabletransport/receiver_test.go @@ -26,7 +26,9 @@ func Test_newReliableReceiver(t *testing.T) { } func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { - log.SetLevel(log.DebugLevel) + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + } type fields struct { incomingPackets incomingSequence @@ -91,7 +93,9 @@ func Test_reliableQueue_MaybeInsertIncoming(t *testing.T) { } func Test_reliableQueue_NextIncomingSequence(t *testing.T) { - log.SetLevel(log.DebugLevel) + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + } type fields struct { lastConsumed model.PacketID diff --git a/internal/reliabletransport/reliable_ack_test.go b/internal/reliabletransport/reliable_ack_test.go index fe48048b..59767974 100644 --- a/internal/reliabletransport/reliable_ack_test.go +++ b/internal/reliabletransport/reliable_ack_test.go @@ -11,9 +11,21 @@ import ( ) // test that everything that is received from below is eventually ACKed to the sender. +/* + + ┌────┐id ┌────┐ + │sndr│◄──┤rcvr│ + └─┬──┘ └──▲─┘ + │ │ + │ │ + │ │ + ▼ send + ack +*/ func TestReliable_ACK(t *testing.T) { - - log.SetLevel(log.DebugLevel) + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + } type args struct { inputSequence []string @@ -91,22 +103,6 @@ func TestReliable_ACK(t *testing.T) { wantacks: 5, }, }, - /* - { - name: "a burst of packets", - args: args{ - inputSequence: []string{ - "[5] CONTROL_V1 +1ms", - "[1] CONTROL_V1 +1ms", - "[3] CONTROL_V1 +1ms", - "[2] CONTROL_V1 +1ms", - "[4] CONTROL_V1 +1ms", - }, - start: 1, - wantacks: 5, - }, - }, - */ } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -148,7 +144,7 @@ func TestReliable_ACK(t *testing.T) { reader := vpntest.NewPacketReader(dataOut) witness := vpntest.NewWitness(reader) - if ok := witness.VerifyNumberOfACKs(tt.args.start, tt.args.wantacks, t0); !ok { + if ok := witness.VerifyNumberOfACKs(tt.args.wantacks, t0); !ok { got := len(witness.Log().ACKs()) t.Errorf("TestACK: got = %v, want %v", got, tt.args.wantacks) } diff --git a/internal/reliabletransport/reliable_loss_test.go b/internal/reliabletransport/reliable_loss_test.go new file mode 100644 index 00000000..7b429bf7 --- /dev/null +++ b/internal/reliabletransport/reliable_loss_test.go @@ -0,0 +1,273 @@ +package reliabletransport + +import ( + "testing" + "time" + + "github.com/apex/log" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/vpntest" +) + +// test that everything that is sent eventually arrives in bounded time, in the pressence of losses. +/* + │ + │ + ▼ + ┌────┐ack┌────┐ + │sndr│◄──┤rcvr│ + └─┬──┘ └──▲─┘ + │ │ +drop◄─┤ │ + │ │ + ▼relay (ack) +*/ +func TestReliable_WithLoss(t *testing.T) { + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + } + + type args struct { + inputSequence []string + inputPayload string + want string + losses []int + } + + tests := []struct { + name string + args args + }{ + // do note that all of the test cases below are using + // unrealistic timing and fast-retransmit (since we're very quickly + // acking a bunch of packets above them) + { + name: "ten ordered packets with no loss", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +1ms", + "[2] CONTROL_V1 +1ms", + "[3] CONTROL_V1 +1ms", + "[4] CONTROL_V1 +1ms", + "[5] CONTROL_V1 +1ms", + "[6] CONTROL_V1 +1ms", + "[7] CONTROL_V1 +1ms", + "[8] CONTROL_V1 +1ms", + "[9] CONTROL_V1 +1ms", + "[10] CONTROL_V1 +1ms", + }, + inputPayload: "aaabbbcccdddeeefffggghhhiiijjj", + want: "aaabbbcccdddeeefffggghhhiiijjj", + losses: []int{}, + }, + }, + { + name: "ten ordered packets, first loss", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +10ms", + "[2] CONTROL_V1 +10ms", + "[3] CONTROL_V1 +10ms", + "[4] CONTROL_V1 +10ms", + "[5] CONTROL_V1 +10ms", + "[6] CONTROL_V1 +10ms", + "[7] CONTROL_V1 +10ms", + "[8] CONTROL_V1 +10ms", + "[9] CONTROL_V1 +10ms", + "[10] CONTROL_V1 +10ms", + }, + inputPayload: "aaabbbcccdddeeefffggghhhiiijjj", + want: "aaabbbcccdddeeefffggghhhiiijjj", + losses: []int{1}, + }, + }, + { + name: "ten ordered packets, 1,3,5,7 loss", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +10ms", + "[2] CONTROL_V1 +10ms", + "[3] CONTROL_V1 +10ms", + "[4] CONTROL_V1 +10ms", + "[5] CONTROL_V1 +10ms", + "[6] CONTROL_V1 +10ms", + "[7] CONTROL_V1 +10ms", + "[8] CONTROL_V1 +10ms", + "[9] CONTROL_V1 +10ms", + "[10] CONTROL_V1 +10ms", + }, + inputPayload: "aaabbbcccdddeeefffggghhhiiijjj", + want: "aaabbbcccdddeeefffggghhhiiijjj", + losses: []int{1, 3, 5, 7}, + }, + }, + { + name: "ten ordered packets, 2,4,6,8 loss", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +10ms", + "[2] CONTROL_V1 +10ms", + "[3] CONTROL_V1 +10ms", + "[4] CONTROL_V1 +10ms", + "[5] CONTROL_V1 +10ms", + "[6] CONTROL_V1 +10ms", + "[7] CONTROL_V1 +10ms", + "[8] CONTROL_V1 +10ms", + "[9] CONTROL_V1 +10ms", + "[10] CONTROL_V1 +10ms", + }, + inputPayload: "aaabbbcccdddeeefffggghhhiiijjj", + want: "aaabbbcccdddeeefffggghhhiiijjj", + losses: []int{2, 4, 6, 8}, + }, + }, + { + name: "ten ordered packets, 2-5 loss, 2 lost again", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +10ms", + "[2] CONTROL_V1 +10ms", + "[3] CONTROL_V1 +10ms", + "[4] CONTROL_V1 +10ms", + "[5] CONTROL_V1 +10ms", + "[6] CONTROL_V1 +10ms", + "[7] CONTROL_V1 +10ms", + "[8] CONTROL_V1 +10ms", + "[9] CONTROL_V1 +10ms", + "[10] CONTROL_V1 +10ms", + }, + inputPayload: "aaabbbcccdddeeefffggghhhiiijjj", + want: "aaabbbcccdddeeefffggghhhiiijjj", + losses: []int{2, 3, 4, 5, 2}, + }, + }, + { + name: "ten out-of-order packets", + args: args{ + inputSequence: []string{ + "[6] CONTROL_V1 +10ms", + "[3] CONTROL_V1 +10ms", + "[1] CONTROL_V1 +10ms", + "[2] CONTROL_V1 +10ms", + "[4] CONTROL_V1 +10ms", + "[5] CONTROL_V1 +10ms", + "[7] CONTROL_V1 +10ms", + "[8] CONTROL_V1 +10ms", + "[9] CONTROL_V1 +10ms", + "[10] CONTROL_V1 +10ms", + }, + inputPayload: "fffcccaaabbbdddeeeggghhhiiijjj", + want: "aaabbbcccdddeeefffggghhhiiijjj", + losses: []int{}, + }, + }, + { + name: "ten out-of-order packets, loss=1,5", + args: args{ + inputSequence: []string{ + "[6] CONTROL_V1 +10ms", + "[3] CONTROL_V1 +10ms", + "[1] CONTROL_V1 +10ms", + "[2] CONTROL_V1 +10ms", + "[4] CONTROL_V1 +10ms", + "[5] CONTROL_V1 +10ms", + "[7] CONTROL_V1 +10ms", + "[8] CONTROL_V1 +10ms", + "[9] CONTROL_V1 +10ms", + "[10] CONTROL_V1 +10ms", + }, + inputPayload: "fffcccaaabbbdddeeeggghhhiiijjj", + want: "aaabbbcccdddeeefffggghhhiiijjj", + losses: []int{1, 5}, + }, + }, + + // TODO(ainghazal): exclude the following tests if not `-short`? + + { + name: "ten ordered packets, first lost 4 times", + args: args{ + inputSequence: []string{ + "[1] CONTROL_V1 +10ms", + "[2] CONTROL_V1 +10ms", + "[3] CONTROL_V1 +10ms", + "[4] CONTROL_V1 +10ms", + "[5] CONTROL_V1 +10ms", + "[6] CONTROL_V1 +10ms", + "[7] CONTROL_V1 +10ms", + "[8] CONTROL_V1 +10ms", + "[9] CONTROL_V1 +10ms", + "[10] CONTROL_V1 +10ms", + }, + inputPayload: "aaabbbcccdddeeefffggghhhiiijjj", + want: "aaabbbcccdddeeefffggghhhiiijjj", + losses: []int{1, 1, 1, 1}, + }, + }, + { + name: "arbitrary text", + args: args{ + inputSequence: []string{"[1..142] CONTROL_V1 +10ms"}, + inputPayload: "I think that the next two generations of Americans will be grappling with the very real specter of finding themselves living in a new and bizarre kind of digital totalitarian state - one that looks and feels democratic on the surface, but has a fierce undercurrent of fear and technologically enforced fascism any time you step out of line. I really hope this isn't the case, but it looks really bad right now, doesn't it?", + want: "I think that the next two generations of Americans will be grappling with the very real specter of finding themselves living in a new and bizarre kind of digital totalitarian state - one that looks and feels democratic on the surface, but has a fierce undercurrent of fear and technologically enforced fascism any time you step out of line. I really hope this isn't the case, but it looks really bad right now, doesn't it?", + losses: []int{1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Service{} + + // where we write stuff (simulates control channel output) + dataIn := make(chan *model.Packet, 1024) + + // where we want to read reordered stuff + dataOut := make(chan *model.Packet, 1024) + + s.ControlToReliable = dataIn + // this one up to control/tls also needs to be buffered because otherwise + // we'll block on the receiver when delivering up. + reliableToControl := dataOut + s.ReliableToControl = &reliableToControl + + // data out from reliable (downwards to network) + toMuxer := make(chan *model.Packet, 1024) + s.DataOrControlToMuxer = &toMuxer // down + + // this will be the data out after losses (simulates upwards to tls) + toNetwork := make(chan *model.Packet, 1024) + + // data in from network (up from muxer) + fromMuxer := make(chan *model.Packet, 1024) + s.MuxerToReliable = fromMuxer // up + + workers, session := initManagers() + + echoServer := vpntest.NewEchoServer(toNetwork, fromMuxer) + echoServer.RemoteSessionID = model.SessionID(session.LocalSessionID()) + session.SetRemoteSessionID(echoServer.LocalSessionID) + + t0 := time.Now() + + // let the workers pump up the jam! + s.StartWorkers(log.Log, workers, session) + + writer := vpntest.NewPacketWriter(dataIn) + go writer.WriteSequenceWithFixedPayload(tt.args.inputSequence, tt.args.inputPayload, 3) + + // start a relay to simulate losses + relay := vpntest.NewPacketRelay(toMuxer, toNetwork) + go relay.RelayWithLosses(tt.args.losses) + defer relay.Stop() + + // start the mock server that echoes payloads with sequenced packets and acks + go echoServer.Start() + defer echoServer.Stop() + + witness := vpntest.NewWitnessFromChannel(dataOut) + if ok := witness.VerifyOrderedPayload(tt.args.want, t0); !ok { + t.Errorf("TestLoss: payload does not match. got=%s, want=%s", witness.Payload(), tt.args.want) + } + }) + } +} diff --git a/internal/reliabletransport/reliable_reorder_test.go b/internal/reliabletransport/reliable_reorder_test.go index b4e94823..bd28a096 100644 --- a/internal/reliabletransport/reliable_reorder_test.go +++ b/internal/reliabletransport/reliable_reorder_test.go @@ -10,9 +10,20 @@ import ( ) // test that we're able to reorder (towards TLS) whatever is received (from the muxer). +// +// dataOut +// ▲ +// │ +// ┌────┐ ┌──┴─┐ +// │sndr│ │rcvr│ +// └────┘ └────┘ +// ▲ +// | +// dataIn func TestReliable_Reordering_UP(t *testing.T) { - - log.SetLevel(log.DebugLevel) + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + } type args struct { inputSequence []string diff --git a/internal/reliabletransport/sender_test.go b/internal/reliabletransport/sender_test.go index 488703ff..eb25afb3 100644 --- a/internal/reliabletransport/sender_test.go +++ b/internal/reliabletransport/sender_test.go @@ -34,7 +34,9 @@ func Test_newReliableSender(t *testing.T) { } func Test_reliableSender_TryInsertOutgoingPacket(t *testing.T) { - log.SetLevel(log.DebugLevel) + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + } type fields struct { inFlight inflightSequence @@ -106,7 +108,9 @@ func Test_reliableSender_TryInsertOutgoingPacket(t *testing.T) { } func Test_reliableSender_NextPacketIDsToACK(t *testing.T) { - log.SetLevel(log.DebugLevel) + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + } type fields struct { pendingACKsToSend []model.PacketID diff --git a/internal/reliabletransport/service_test.go b/internal/reliabletransport/service_test.go index 6354d413..3555fd7f 100644 --- a/internal/reliabletransport/service_test.go +++ b/internal/reliabletransport/service_test.go @@ -52,7 +52,7 @@ func TestService_StartWorkers(t *testing.T) { }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(tt.name, func(_ *testing.T) { s := &Service{ DataOrControlToMuxer: tt.fields.DataOrControlToMuxer, ControlToReliable: tt.fields.ControlToReliable, diff --git a/internal/vpntest/doc.go b/internal/vpntest/doc.go new file mode 100644 index 00000000..6e910c6c --- /dev/null +++ b/internal/vpntest/doc.go @@ -0,0 +1,2 @@ +// Package vpntest provides utitities to facilitate testing different minivpn packages. +package vpntest diff --git a/internal/vpntest/packetio.go b/internal/vpntest/packetio.go index 923bfc84..1e53a03d 100644 --- a/internal/vpntest/packetio.go +++ b/internal/vpntest/packetio.go @@ -1,10 +1,15 @@ package vpntest import ( + "fmt" + "regexp" "slices" + "strconv" + "sync" "time" "github.com/apex/log" + "github.com/ooni/minivpn/internal/bytesx" "github.com/ooni/minivpn/internal/model" ) @@ -18,6 +23,9 @@ type PacketWriter struct { // RemoteSessionID is needed to produce ACKs. RemoteSessionID model.SessionID + + payload string + packetPayloadSize int } // NewPacketWriter creates a new PacketWriter. @@ -27,22 +35,97 @@ func NewPacketWriter(ch chan<- *model.Packet) *PacketWriter { // WriteSequence writes the passed packet sequence (in their string representation) // to the configured channel. It will wait the specified interval between one packet and the next. +// The input sequence strings will be expanded for range notation, as in [1..10] func (pw *PacketWriter) WriteSequence(seq []string) { - for _, testStr := range seq { - testPkt, err := NewTestPacketFromString(testStr) - if err != nil { - panic("PacketWriter: error reading test sequence:" + err.Error()) + for _, expr := range seq { + for _, item := range maybeExpand(expr) { + pw.writeSequenceItem(item) } + } +} + +// possibly expand a input sequence in range notation for the packet ids [1..10] +func maybeExpand(input string) []string { + fmt.Println("maybe expand") + items := []string{} + pattern := `^\[(\d+)\.\.(\d+)\] (.+)` + regexpPattern := regexp.MustCompile(pattern) + matches := regexpPattern.FindStringSubmatch(input) + if len(matches) != 4 { + // not a range, return the single element + items = append(items, input) + return items + } + + fmt.Println("len matches", len(matches)) + + // extract beginning and end of the range + fromStr := matches[1] + toStr := matches[2] + body := matches[3] + + // convert to int (from/to ) + from, err := strconv.Atoi(fromStr) + if err != nil { + panic(err) + } + + to, err := strconv.Atoi(toStr) + if err != nil { + panic(err) + } - p := &model.Packet{ - Opcode: testPkt.Opcode, - RemoteSessionID: pw.RemoteSessionID, - LocalSessionID: pw.LocalSessionID, - ID: model.PacketID(testPkt.ID), + // return the expanded id range + for i := from; i <= to; i++ { + items = append(items, fmt.Sprintf("[%d] %s", i, body)) + } + return items +} + +func (pw *PacketWriter) writeSequenceItem(item string) { + testPkt, err := NewTestPacketFromString(item) + if err != nil { + panic("PacketWriter: error reading test sequence:" + err.Error()) + } + p := &model.Packet{ + Opcode: testPkt.Opcode, + RemoteSessionID: pw.RemoteSessionID, + LocalSessionID: pw.LocalSessionID, + ID: model.PacketID(testPkt.ID), + } + if len(pw.payload) > 0 { + var payload, rest string + size := pw.packetPayloadSize + if len(pw.payload) < size { + payload = pw.payload + pw.payload = "" + } else { + payload, rest = pw.payload[:size], pw.payload[size:] + pw.payload = rest } - pw.ch <- p - time.Sleep(testPkt.IAT) + p.Payload = []byte(payload) } + pw.ch <- p + time.Sleep(testPkt.IAT) +} + +// WriteSequenceWithFixedPayload will write packets according to the sequence specified in seq, +// but will sequentially pick the payload from the passed payload string, in increments defined by size. +func (pw *PacketWriter) WriteSequenceWithFixedPayload(seq []string, payload string, size int) { + pw.payload = payload + pw.packetPayloadSize = size + pw.WriteSequence(seq) +} + +// WritePacketWithID writes a dummy control packet with the passed ID. +func (pw *PacketWriter) WritePacketWithID(i int) { + p := &model.Packet{ + Opcode: model.P_CONTROL_V1, + RemoteSessionID: pw.RemoteSessionID, + LocalSessionID: pw.LocalSessionID, + ID: model.PacketID(i), + } + pw.ch <- p } // LoggedPacket is a trace of a received packet. @@ -92,8 +175,9 @@ func (l PacketLog) ACKs() []int { // PacketReader reads packets from a channel. type PacketReader struct { - ch <-chan *model.Packet - log []*LoggedPacket + ch <-chan *model.Packet + log []*LoggedPacket + payload []byte } // NewPacketReader creates a new PacketReader. @@ -102,6 +186,11 @@ func NewPacketReader(ch <-chan *model.Packet) *PacketReader { return &PacketReader{ch: ch, log: logged} } +// Payload returns the string payload constructed from the payloads in the received packets. +func (pr *PacketReader) Payload() string { + return string(pr.payload) +} + // WaitForSequence loops reading from the internal channel until the logged // sequence matches the len of the expected sequence; it returns // true if the obtained packet ID sequence matches the expected one. @@ -112,14 +201,13 @@ func (pr *PacketReader) WaitForSequence(seq []int, start time.Time) bool { break } // no, so let's keep reading until the test runner kills us - pkt := <-pr.ch - pr.log = append(pr.log, newLoggedPacket(pkt, start)) - log.Debugf("got packet: %v", pkt.ID) + pr.appendOneIncomingPacket(start) } // TODO(ainghazal): move the comparison to witness, leave only wait here return slices.Equal(seq, PacketLog(pr.log).IDSequence()) } +// WaitForNumberOfACKs will read from the channel until the given number of acks have been received. func (pr *PacketReader) WaitForNumberOfACKs(total int, start time.Time) { for { // have we read enough acks to call it a day? @@ -127,12 +215,31 @@ func (pr *PacketReader) WaitForNumberOfACKs(total int, start time.Time) { break } // no, so let's keep reading until the test runner kills us - pkt := <-pr.ch - pr.log = append(pr.log, newLoggedPacket(pkt, start)) - log.Debugf("got packet: %v", pkt.ID) + pr.appendOneIncomingPacket(start) } } +// WaitForOrderedPayloadLen will read from the channel until the given number of characters have been read. +func (pr *PacketReader) WaitForOrderedPayloadLen(total int, start time.Time) { + for { + // have we read enough packets to call it a day? + if len(pr.payload) >= total { + break + } + // no, so let's keep reading until the test runner kills us + pr.appendOneIncomingPacket(start) + } +} + +func (pr *PacketReader) appendOneIncomingPacket(t0 time.Time) { + pkt := <-pr.ch + pr.log = append(pr.log, newLoggedPacket(pkt, t0)) + if pkt.Payload != nil { + pr.payload = append(pr.payload, pkt.Payload...) + } + log.Debugf("got packet: %v (%d bytes)", pkt.ID, len(pkt.Payload)) +} + // Log returns the log of the received packets. func (pr *PacketReader) Log() PacketLog { return PacketLog(pr.log) @@ -148,6 +255,12 @@ func NewWitness(r *PacketReader) *Witness { return &Witness{r} } +// NewWitnessFromChannel constructs a Witness from a channel of packets. +func NewWitnessFromChannel(ch <-chan *model.Packet) *Witness { + return NewWitness(NewPacketReader(ch)) + +} + // Log returns the packet log from the internal reader this witness uses. func (w *Witness) Log() PacketLog { return w.reader.Log() @@ -155,12 +268,23 @@ func (w *Witness) Log() PacketLog { // VerifyNumberOfACKs tells the underlying reader to wait for a given number of acks, // returns true if we have the same number of acks. -func (w *Witness) VerifyNumberOfACKs(start, total int, t time.Time) bool { - w.reader.WaitForNumberOfACKs(total, t) +func (w *Witness) VerifyNumberOfACKs(total int, start time.Time) bool { + w.reader.WaitForNumberOfACKs(total, start) return len(w.Log().ACKs()) == total } -// contains check if the element is in the slice. this is expensive, but it's only +// VerifyOrderedPayload checks that the received payload matches the one we expect. +func (w *Witness) VerifyOrderedPayload(payload string, t time.Time) bool { + w.reader.WaitForOrderedPayloadLen(len(payload), t) + return w.reader.Payload() == payload +} + +// Payload returns the string payload reconstructed from the received packets. +func (w *Witness) Payload() string { + return w.reader.Payload() +} + +// contains checks if the element is in the slice. this is expensive, but it's only // for tests and the alternative is to make ackSet public. func contains(slice []int, target int) bool { for _, item := range slice { @@ -170,3 +294,143 @@ func contains(slice []int, target int) bool { } return false } + +// PacketRelay sends any received packet, without modifications. +type PacketRelay struct { + dataIn <-chan *model.Packet + dataOut chan<- *model.Packet + + closeOnce sync.Once + mu sync.Mutex // Guards cancel + cancel chan struct{} +} + +// NewPacketRelay reads packets from one channel and writes them to another. +func NewPacketRelay(dataIn <-chan *model.Packet, dataOut chan<- *model.Packet) *PacketRelay { + return &PacketRelay{ + dataIn: dataIn, + dataOut: dataOut, + + mu: sync.Mutex{}, + cancel: make(chan struct{}), + } +} + +// RelayWithLosses will relay incoming packets according to a vector of packetID that must be dropped. +// To specify repeated losses for a packet ID, the vector of losses must repeat the id several times. +func (pr *PacketRelay) RelayWithLosses(losses []int) { + ctr := makeLossMap(losses) + for { + select { + case <-pr.cancel: + return + case p := <-pr.dataIn: + id := int(p.ID) + cnt, ok := ctr[id] + if !ok || cnt <= 0 { + // not on the loss map, or we already saw the packet enough times + log.Debugf("relay packet: %v (%s)", id, string(p.Payload)) + pr.dataOut <- p + } else { + log.Debugf("relay: drop packet: %v", id) + } + // decrement the counter for this packet ID + ctr[id]-- + } + } +} + +// Stop will stop the relay loop. +func (pr *PacketRelay) Stop() { + pr.closeOnce.Do(func() { + close(pr.cancel) + }) + +} + +// makeLossMap returns a map from packet IDs to int. The value +// of the map represent how many times we have to observe a given packet ID +// before relaying it. +func makeLossMap(l []int) map[int]int { + lc := make(map[int]int) + for _, i := range l { + _, ok := lc[i] + if !ok { + lc[i] = 1 + } else { + lc[i]++ + } + } + return lc +} + +// EchoServer is a dummy server intended for testing. It will: +// - send sequential packets back to a client implementation, containing each the same payload +// and the same packet ID than incoming. +// - write every seen packet on the ACK array for the echo response. +type EchoServer struct { + dataIn chan *model.Packet + dataOut chan *model.Packet + + // local counter for packet id + outPacketID int + + // LocalSessionID is needed to produce incoming packets that pass sanity checks. + LocalSessionID model.SessionID + + // RemoteSessionID is needed to produce ACKs. + RemoteSessionID model.SessionID + + closeOnce sync.Once + mu sync.Mutex // Guards cancel + cancel chan struct{} +} + +// NewEchoServer creates an [EchoServer] given two channels of [model.Packet]s. +func NewEchoServer(dataIn, dataOut chan *model.Packet) *EchoServer { + randomSessionID, err := bytesx.GenRandomBytes(8) + if err != nil { + panic(err) + } + return &EchoServer{ + dataIn: dataIn, + dataOut: dataOut, + outPacketID: 1, + LocalSessionID: model.SessionID(randomSessionID), + RemoteSessionID: [8]byte{}, + closeOnce: sync.Once{}, + mu: sync.Mutex{}, + cancel: make(chan struct{}), + } +} + +// Start starts the [EchoServer]. +func (e *EchoServer) Start() { + for { + select { + case <-e.cancel: + return + case p := <-e.dataIn: + e.replyToPacketWithPayload(p.Payload, p.ID) + } + } +} + +// Stop stops the [EchoServer]. +func (e *EchoServer) Stop() { + e.closeOnce.Do(func() { + close(e.cancel) + }) +} + +func (e *EchoServer) replyToPacketWithPayload(payload []byte, toACK model.PacketID) { + p := &model.Packet{ + Opcode: model.P_CONTROL_V1, + RemoteSessionID: e.RemoteSessionID, + LocalSessionID: e.LocalSessionID, + ID: toACK, + Payload: payload, + ACKs: []model.PacketID{toACK}, + } + e.dataOut <- p +} diff --git a/internal/vpntest/packetio_test.go b/internal/vpntest/packetio_test.go index 3a312c7f..f39f9903 100644 --- a/internal/vpntest/packetio_test.go +++ b/internal/vpntest/packetio_test.go @@ -1,12 +1,59 @@ package vpntest import ( + "bytes" + "reflect" + "slices" "testing" "time" + "github.com/apex/log" "github.com/ooni/minivpn/internal/model" ) +func TestPacketLog_ACKs(t *testing.T) { + tests := []struct { + name string + l PacketLog + want []int + }{ + { + name: "no acks", + l: []*LoggedPacket{}, + want: []int{}, + }, + { + name: "one ack packet", + l: []*LoggedPacket{ + {ACKs: []model.PacketID{0}}, + }, + want: []int{0}, + }, + { + name: "one ack packet with two acks", + l: []*LoggedPacket{ + {ACKs: []model.PacketID{1, 0}}, + }, + want: []int{1, 0}, + }, + { + name: "two ack packets with two acks each", + l: []*LoggedPacket{ + {ACKs: []model.PacketID{1, 0}}, + {ACKs: []model.PacketID{3, 2}}, + }, + want: []int{1, 0, 3, 2}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.l.ACKs(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("PacketLog.ACKs() = %v, want %v", got, tt.want) + } + }) + } +} + func TestPacketReaderWriter(t *testing.T) { type args struct { input []string @@ -79,3 +126,332 @@ func TestPacketReaderWriter(t *testing.T) { }) } } + +func TestPacketWriter_WriteExpandedSequence(t *testing.T) { + tests := []struct { + name string + seq []string + wantIDs []int + }{ + { + name: "test range expansion", + seq: []string{"[1..5] CONTROL_V1 +1ms"}, + wantIDs: []int{1, 2, 3, 4}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := make(chan *model.Packet, 20) + pw := NewPacketWriter(ch) + pw.WriteSequence(tt.seq) + + got := make([]int, 0) + for i := 0; i < len(tt.wantIDs); i++ { + p := <-ch + got = append(got, int(p.ID)) + } + if !slices.Equal(got, tt.wantIDs) { + t.Errorf("WriteExpandedSequence() got = %v, want %v", got, tt.wantIDs) + } + }) + } +} + +func TestWitness_VerifyOrderedPayload(t *testing.T) { + type args struct { + packets []*model.Packet + payload string + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "simple payload, tree packets", + args: args{ + packets: []*model.Packet{ + { + ID: 1, + Payload: []byte("aaa"), + }, + { + ID: 2, + Payload: []byte("bbb"), + }, + { + ID: 3, + Payload: []byte("ccc"), + }, + }, + payload: "aaabbbccc", + }, + want: true, + }, + { + name: "longer payload, two packets", + args: args{ + packets: []*model.Packet{ + { + ID: 1, + Payload: []byte("aaaaaaaaaaaaaaa"), + }, + { + ID: 2, + Payload: []byte("bbbbbbbbbbbbbbb"), + }, + }, + payload: "aaaaaaaaaaaaaaabbbbbbbbbbbbbbb", + }, + want: true, + }, + { + name: "empty payload no packets", + args: args{ + packets: []*model.Packet{}, + payload: "", + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := make(chan *model.Packet, 20) + w := NewWitnessFromChannel(ch) + + for _, p := range tt.args.packets { + ch <- p + } + t0 := time.Now() + if got := w.VerifyOrderedPayload(tt.args.payload, t0); got != tt.want { + t.Errorf("Witness.VerifyOrderedPayload() = %v, want %v", got, tt.want) + } + if w.Payload() != tt.args.payload { + t.Errorf("Witness.Payload() = %v, want %v", w.Payload(), tt.want) + } + }) + } +} + +func TestPacketRelay_RelayWithLosses(t *testing.T) { + log.SetLevel(log.DebugLevel) + type fields struct { + dataIn chan *model.Packet + dataOut chan *model.Packet + RemoteSessionID model.SessionID + } + type args struct { + packetsIn []int + losses []int + wantOut []int + } + tests := []struct { + name string + fields fields + args args + }{ + { + name: "zero loss", + fields: fields{ + dataIn: make(chan *model.Packet, 100), + dataOut: make(chan *model.Packet, 100), + }, + args: args{ + packetsIn: []int{1, 2, 3, 4}, + losses: []int{}, + wantOut: []int{1, 2, 3, 4}, + }, + }, + { + name: "zero loss, repeated ids", + fields: fields{ + dataIn: make(chan *model.Packet, 100), + dataOut: make(chan *model.Packet, 100), + }, + args: args{ + packetsIn: []int{1, 2, 3, 4, 1}, + losses: []int{}, + wantOut: []int{1, 2, 3, 4, 1}, + }, + }, + { + name: "loss for even ids", + fields: fields{ + dataIn: make(chan *model.Packet, 100), + dataOut: make(chan *model.Packet, 100), + }, + args: args{ + packetsIn: []int{1, 2, 3, 4, 5}, + losses: []int{2, 4}, + wantOut: []int{1, 3, 5}, + }, + }, + { + name: "loss for first match", + fields: fields{ + dataIn: make(chan *model.Packet, 100), + dataOut: make(chan *model.Packet, 100), + }, + args: args{ + packetsIn: []int{1, 2, 3, 4, 5, 1, 2}, + losses: []int{1, 2}, + wantOut: []int{3, 4, 5, 1, 2}, + }, + }, + { + name: "loss for two matches", + fields: fields{ + dataIn: make(chan *model.Packet, 100), + dataOut: make(chan *model.Packet, 100), + }, + args: args{ + packetsIn: []int{1, 2, 3, 2, 1, 4, 5, 1, 2}, + losses: []int{1, 1, 2, 2}, + wantOut: []int{3, 4, 5, 1, 2}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pr := NewPacketRelay( + tt.fields.dataIn, + tt.fields.dataOut, + ) + go pr.RelayWithLosses(tt.args.losses) + writer := NewPacketWriter(tt.fields.dataIn) + for _, id := range tt.args.packetsIn { + writer.WritePacketWithID(id) + } + got := readPacketIDSequence(tt.fields.dataOut, len(tt.args.wantOut)) + pr.Stop() + if !slices.Equal(got, tt.args.wantOut) { + t.Errorf("relayWithLosses: got = %v, want %v", got, tt.args.wantOut) + } + }) + } +} + +func readPacketIDSequence(ch chan *model.Packet, wantLen int) []int { + var got []int + for { + pkt := <-ch + got = append(got, int(pkt.ID)) + if len(got) >= wantLen { + break + } + } + return got +} + +func TestPacketWriter_WriteSequenceWithFixedPayload(t *testing.T) { + type args struct { + seq []string + payload string + size int + } + tests := []struct { + name string + args args + }{ + { + name: "string payload with 2 char per packet", + args: args{ + seq: []string{ + "[1] CONTROL_V1 +0ms", + "[2] CONTROL_V1 +0ms", + "[3] CONTROL_V1 +0ms", + "[4] CONTROL_V1 +0ms", + "[5] CONTROL_V1 +0ms", + "[6] CONTROL_V1 +0ms", + "[7] CONTROL_V1 +0ms", + "[8] CONTROL_V1 +0ms", + "[9] CONTROL_V1 +0ms", + "[10] CONTROL_V1 +0ms", + }, + payload: "aabbccddeeffgghhiijj", + size: 2, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := make(chan *model.Packet, 20) + pw := NewPacketWriter(ch) + pw.WriteSequenceWithFixedPayload(tt.args.seq, tt.args.payload, tt.args.size) + + got := "" + for i := 0; i < len(tt.args.seq); i++ { + p := <-ch + got = got + string(p.Payload) + } + if got != tt.args.payload { + t.Errorf("WriteSequenceWithFixedPayload: got = %v, want %v", got, tt.args.payload) + } + }) + } +} + +// test that we're able to start/stop an echo server, and that +// it returns the same that is delivered. +func TestEchoServer_StartStop(t *testing.T) { + type args struct { + dataIn []*model.Packet + } + tests := []struct { + name string + args args + }{ + { + name: "no packets in", + args: args{}, + }, + { + name: "one packet in", + args: args{ + []*model.Packet{ + {ID: 1}}, + }, + }, + { + name: "three packet in with payloads", + args: args{ + []*model.Packet{ + {ID: 1, Payload: []byte("aaa")}, + {ID: 2, Payload: []byte("bbb")}, + {ID: 3, Payload: []byte("ccc")}, + }, + }, + }, + } + for _, tt := range tests { + dataIn := make(chan *model.Packet, 1024) + dataOut := make(chan *model.Packet, 1024) + t.Run(tt.name, func(t *testing.T) { + e := NewEchoServer(dataIn, dataOut) + go e.Start() + got := make([]*model.Packet, 0) + for _, p := range tt.args.dataIn { + dataIn <- p + } + for range tt.args.dataIn { + p := <-dataOut + got = append(got, p) + } + e.Stop() + + if len(got) != len(tt.args.dataIn) { + t.Errorf("TestEchoServer_StartStop: got len = %v, want %v", len(got), len(tt.args.dataIn)) + } + for i := range got { + gotPacket := got[i] + wantPacket := tt.args.dataIn[i] + if gotPacket.ID != wantPacket.ID { + t.Errorf("TestEchoServer_StartStop: packet %d: got ID = %v, want %v", i, gotPacket.ID, wantPacket.ID) + } + if !bytes.Equal(gotPacket.Payload, wantPacket.Payload) { + t.Errorf("TestEchoServer_StartStop: packet %d: got Payload = %v, want Payload %v", i, gotPacket.Payload, wantPacket.Payload) + } + } + }) + } +} diff --git a/internal/vpntest/vpntest.go b/internal/vpntest/vpntest.go index bdf67fce..46e55826 100644 --- a/internal/vpntest/vpntest.go +++ b/internal/vpntest/vpntest.go @@ -28,8 +28,10 @@ type TestPacket struct { IAT time.Duration } -// the test packet string is in the form: -// "[ID] OPCODE +42ms" +// NewTestPacketFromString constructs a new TestPacket. The input +// representation for the test packet string is in the form: +// "[ID] OPCODE (ack:) +42ms" +// where the ack array is optional. func NewTestPacketFromString(s string) (*TestPacket, error) { parts := strings.Split(s, " +")