Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ckousik committed Dec 6, 2022
1 parent f3b4f4d commit 158b01f
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 24 deletions.
26 changes: 15 additions & 11 deletions p2p/transport/webrtc/datachannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,18 @@ var _ network.MuxedStream = &dataChannel{}
const (
// maxMessageSize is limited to 16384 bytes in the SDP.
maxMessageSize int = 16384
// maxMessageSize is set to 1MB since pion SCTP streams have
// an internal buffer size of 1MB by default. Currently, there is
// no method to change this value when creating a datachannel
// or a PeerConnection.
// Pion SCTP association has an internal receive buffer of 1MB (roughly, 1MB per connection).
// Currently, there is no way to change this value via the WebRTC API.
// https://github.com/pion/sctp/blob/c0159aa2d49c240362038edf88baa8a9e6cfcede/association.go#L47
maxBufferedAmount int = 1024 * 1024
maxBufferedAmount int = 2 * maxMessageSize
// bufferedAmountLowThreshold and maxBufferedAmount are bound
// to a stream but congestion control is done on the whole
// SCTP association. This means that a single stream can monopolize
// the complete congestion control window (cwnd) if it does not
// read stream data and it's remote continues to send. We can
// add messages to the send buffer once there is space for 1 full
// sized message.
bufferedAmountLowThreshold uint64 = uint64(maxMessageSize)
bufferedAmountLowThreshold uint64 = uint64(maxBufferedAmount) / 2

protoOverhead int = 5
varintOverhead int = 2
Expand Down Expand Up @@ -235,23 +233,23 @@ func (d *dataChannel) partialWrite(b []byte) (int, error) {
case <-timeout:
return 0, os.ErrDeadlineExceeded
case <-writeAvailable:
return d.writeMessage(msg)
return len(b), d.writeMessage(msg)
case <-d.ctx.Done():
return 0, io.ErrClosedPipe
case <-deadlineUpdated:

}
} else {
return d.writeMessage(msg)
return len(b), d.writeMessage(msg)
}
}
}

func (d *dataChannel) writeMessage(msg *pb.Message) (int, error) {
func (d *dataChannel) writeMessage(msg *pb.Message) error {
err := d.writer.WriteMsg(msg)
// this only returns the number of bytes sent from the buffer
// requested by the user.
return len(msg.GetMessage()), err
return err

}

Expand Down Expand Up @@ -333,7 +331,7 @@ func (d *dataChannel) Reset() error {
var err error
d.resetOnce.Do(func() {
msg := &pb.Message{Flag: pb.Message_RESET.Enum()}
_, err = d.writeMessage(msg)
err = d.writeMessage(msg)
err = d.Close()
})
return err
Expand Down Expand Up @@ -373,6 +371,12 @@ func (d *dataChannel) getState() channelState {
return d.state
}

// readLoop is required for both reads and writes since calling `Read`
// on the underlying datachannel blocks indefinitely until data is available
// or the datachannel is closed. Having Read run in a separate Goroutine driven
// by the stream's `Read` call allows setting deadlines on the stream's `Read`
// and also allows `Write` to read message flags in a non-blocking way after the
// stream stops reading.
func (d *dataChannel) readLoop() {
defer d.wg.Done()
for {
Expand Down
45 changes: 44 additions & 1 deletion p2p/transport/webrtc/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,53 @@ func TestTransportWebRTC_StreamSetWriteDeadline(t *testing.T) {
require.NoError(t, err)

stream.SetWriteDeadline(time.Now().Add(200 * time.Millisecond))
_, err = stream.Write(make([]byte, 2*maxBufferedAmount))
largeBuffer := make([]byte, 2*1024*1024)
_, err = stream.Write(largeBuffer)
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
}

func TestTransportWebRTC_StreamWriteBufferContention(t *testing.T) {
tr, listeningPeer := getTransport(t)
listenMultiaddr, err := multiaddr.NewMultiaddr(fmt.Sprintf("/ip4/%s/udp/0/webrtc", listenerIp))
require.NoError(t, err)
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)

tr1, connectingPeer := getTransport(t)

for i := 0; i < 2; i++ {
go func() {
lconn, err := listener.Accept()
require.NoError(t, err)
require.Equal(t, connectingPeer, lconn.RemotePeer())
_, err = lconn.AcceptStream()
require.NoError(t, err)
}()

}

conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)

errC := make(chan error)
// writers
for i := 0; i < 2; i++ {
go func() {
stream, err := conn.OpenStream(context.Background())
require.NoError(t, err)

stream.SetWriteDeadline(time.Now().Add(200 * time.Millisecond))
largeBuffer := make([]byte, 2*1024*1024)
_, err = stream.Write(largeBuffer)
errC <- err
}()
}

require.ErrorIs(t, <-errC, os.ErrDeadlineExceeded)
require.ErrorIs(t, <-errC, os.ErrDeadlineExceeded)

}

func TestTransportWebRTC_ReadPartialMessage(t *testing.T) {
tr, listeningPeer := getTransport(t)
listenMultiaddr, err := multiaddr.NewMultiaddr(fmt.Sprintf("/ip4/%s/udp/0/webrtc", listenerIp))
Expand Down
14 changes: 14 additions & 0 deletions p2p/transport/webrtc/udpmux/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func (mux *udpMux) RemoveConnByUfrag(ufrag string) {
for _, isIPv6 := range []bool{true, false} {
key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}
if conn, ok := mux.ufragMap[key]; ok {
_ = conn.closeConnection()
removedAddresses = append(removedAddresses, conn.addresses...)
delete(mux.ufragMap, key)
}
Expand Down Expand Up @@ -169,6 +170,19 @@ func (mux *udpMux) readLoop() {
}
}

func (mux *udpMux) hasConn(ufrag string) net.PacketConn {
mux.mu.Lock()
defer mux.mu.Unlock()

for _, isIPv6 := range []bool{true, false} {
key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}
if conn, ok := mux.ufragMap[key]; ok {
return conn
}
}
return nil
}

func ufragFromStunMessage(msg *stun.Message, local_ufrag bool) (string, error) {
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr != nil {
Expand Down
63 changes: 63 additions & 0 deletions p2p/transport/webrtc/udpmux/mux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package udpmux

import (
"net"
"testing"
"time"

"github.com/stretchr/testify/require"
)

var _ net.PacketConn = dummyPacketConn{}

type dummyPacketConn struct{}

// Close implements net.PacketConn
func (dummyPacketConn) Close() error {
return nil
}

// LocalAddr implements net.PacketConn
func (dummyPacketConn) LocalAddr() net.Addr {
return nil
}

// ReadFrom implements net.PacketConn
func (dummyPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return 0, &net.UDPAddr{}, nil
}

// SetDeadline implements net.PacketConn
func (dummyPacketConn) SetDeadline(t time.Time) error {
return nil
}

// SetReadDeadline implements net.PacketConn
func (dummyPacketConn) SetReadDeadline(t time.Time) error {
return nil
}

// SetWriteDeadline implements net.PacketConn
func (dummyPacketConn) SetWriteDeadline(t time.Time) error {
return nil
}

// WriteTo implements net.PacketConn
func (dummyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, nil
}

func TestUDPMux_RemoveConnectionOnClose(t *testing.T) {
mux := NewUDPMux(dummyPacketConn{}, nil)
conn, err := mux.GetConn("test", false)
require.NoError(t, err)
require.NotNil(t, conn)

m := mux.(*udpMux)
require.NotNil(t, m.hasConn("test"))

err = conn.Close()
require.NoError(t, err)

require.Nil(t, m.hasConn("test"))
}
32 changes: 20 additions & 12 deletions p2p/transport/webrtc/udpmux/muxed_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ package udpmux

import (
"context"
"fmt"
"net"
"time"
)

var _ net.PacketConn = &muxedConnection{}

type muxedConnection struct {
ctx context.Context
cancel context.CancelFunc
buffer *packetBuffer
ctx context.Context
cancelFunc context.CancelFunc
buffer *packetBuffer
// list of remote addresses associated with this connection.
// this is useful as a mapping from [address] -> ufrag
addresses []string
Expand All @@ -22,11 +23,11 @@ type muxedConnection struct {
func newMuxedConnection(mux *udpMux, ufrag string) *muxedConnection {
ctx, cancel := context.WithCancel(context.Background())
return &muxedConnection{
ctx: ctx,
cancel: cancel,
buffer: newPacketBuffer(ctx),
ufrag: ufrag,
mux: mux,
ctx: ctx,
cancelFunc: cancel,
buffer: newPacketBuffer(ctx),
ufrag: ufrag,
mux: mux,
}
}

Expand All @@ -36,12 +37,9 @@ func (conn *muxedConnection) push(buf []byte, addr net.Addr) error {

// Close implements net.PacketConn
func (conn *muxedConnection) Close() error {
select {
case <-conn.ctx.Done():
if err := conn.closeConnection(); err != nil {
return nil
default:
}
conn.cancel()
conn.mux.RemoveConnByUfrag(conn.ufrag)
return nil
}
Expand Down Expand Up @@ -75,3 +73,13 @@ func (*muxedConnection) SetWriteDeadline(t time.Time) error {
func (conn *muxedConnection) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return conn.mux.writeTo(p, addr)
}

func (conn *muxedConnection) closeConnection() error {
select {
case <-conn.ctx.Done():
return fmt.Errorf("already closed")
default:
}
conn.cancelFunc()
return nil
}

0 comments on commit 158b01f

Please sign in to comment.