diff --git a/association.go b/association.go
index 91b92fc3..bf8f21d1 100644
--- a/association.go
+++ b/association.go
@@ -221,8 +221,9 @@ type Association struct {
 	delayedAckTriggered   bool
 	immediateAckTriggered bool
 
-	name string
-	log  logging.LeveledLogger
+	name          string
+	log           logging.LeveledLogger
+	streamVersion uint32
 }
 
 // Config collects the arguments to createAssociation construction into
@@ -1368,6 +1369,7 @@ func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream
 		streamIdentifier: streamIdentifier,
 		reassemblyQueue:  newReassemblyQueue(streamIdentifier),
 		log:              a.log,
+		version:          atomic.AddUint32(&a.streamVersion, 1),
 		name:             fmt.Sprintf("%d:%s", streamIdentifier, a.name),
 	}
 
@@ -2088,10 +2090,14 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1
 		dataLen := uint32(len(c.userData))
 		if dataLen == 0 {
 			sisToReset = append(sisToReset, c.streamIdentifier)
-			err := a.pendingQueue.pop(c)
-			if err != nil {
-				a.log.Errorf("failed to pop from pending queue: %s", err.Error())
-			}
+			a.popPendingDataChunksToDrop(c)
+			continue
+		}
+
+		s, ok := a.streams[c.streamIdentifier]
+
+		if !ok || s.State() > StreamStateOpen || s.version != c.streamVersion {
+			a.popPendingDataChunksToDrop(c)
 			continue
 		}
 
@@ -2123,6 +2129,13 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1
 	return chunks, sisToReset
 }
 
+func (a *Association) popPendingDataChunksToDrop(c *chunkPayloadData) {
+	err := a.pendingQueue.pop(c)
+	if err != nil {
+		a.log.Errorf("failed to pop from pending queue: %s", err.Error())
+	}
+}
+
 // bundleDataChunksIntoPackets packs DATA chunks into packets. It tries to bundle
 // DATA chunks into a packet so long as the resulting packet size does not exceed
 // the path MTU.
diff --git a/chunk_payload_data.go b/chunk_payload_data.go
index eecc7c8c..294de8ab 100644
--- a/chunk_payload_data.go
+++ b/chunk_payload_data.go
@@ -71,7 +71,8 @@ type chunkPayloadData struct {
 	// chunk is still in the inflight queue
 	retransmit bool
 
-	head *chunkPayloadData // link to the head of the fragment
+	head          *chunkPayloadData // link to the head of the fragment
+	streamVersion uint32
 }
 
 const (
diff --git a/stream.go b/stream.go
index fd2c5fc7..c1fbece3 100644
--- a/stream.go
+++ b/stream.go
@@ -24,12 +24,12 @@ const (
 
 // StreamState is an enum for SCTP Stream state field
 // This field identifies the state of stream.
-type StreamState int
+type StreamState int32
 
 // StreamState enums
 const (
 	StreamStateOpen    StreamState = iota // Stream object starts with StreamStateOpen
-	StreamStateClosing                    // Outgoing stream is being reset
+	StreamStateClosing                    // Stream is closed by remote
 	StreamStateClosed                     // Stream has been closed
 )
 
@@ -71,6 +71,7 @@ type Stream struct {
 	state            StreamState
 	log              logging.LeveledLogger
 	name             string
+	version          uint32
 }
 
 // StreamIdentifier returns the Stream identifier associated to the stream.
@@ -296,6 +297,7 @@ func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPa
 		copy(userData, raw[i:i+fragmentSize])
 
 		chunk := &chunkPayloadData{
+			streamVersion:    s.version,
 			streamIdentifier: s.streamIdentifier,
 			userData:         userData,
 			unordered:        unordered,
@@ -338,16 +340,22 @@ func (s *Stream) Close() error {
 		s.lock.Lock()
 		defer s.lock.Unlock()
 
-		s.log.Debugf("[%s] Close: state=%s", s.name, s.state.String())
+		state := s.State()
+		s.log.Debugf("[%s] Close: state=%s", s.name, state.String())
-		if s.state == StreamStateOpen {
-			if s.readErr == nil {
-				s.state = StreamStateClosing
-			} else {
-				s.state = StreamStateClosed
-			}
-			s.log.Debugf("[%s] state change: open => %s", s.name, s.state.String())
+		switch state {
+		case StreamStateOpen:
+			s.SetState(StreamStateClosed)
+			s.log.Debugf("[%s] state change: open => closed", s.name)
+			s.readErr = io.EOF
+			s.readNotifier.Broadcast()
+			return s.streamIdentifier, true
+		case StreamStateClosing:
+			s.SetState(StreamStateClosed)
+			s.log.Debugf("[%s] state change: closing => closed", s.name)
 			return s.streamIdentifier, true
+		case StreamStateClosed:
+			return s.streamIdentifier, false
 		}
 
 		return s.streamIdentifier, false
 	}(); resetOutbound {
@@ -434,7 +442,8 @@ func (s *Stream) onInboundStreamReset() {
 	s.lock.Lock()
 	defer s.lock.Unlock()
 
-	s.log.Debugf("[%s] onInboundStreamReset: state=%s", s.name, s.state.String())
+	state := s.State()
+	s.log.Debugf("[%s] onInboundStreamReset: state=%s", s.name, state.String())
 
 	// No more inbound data to read. Unblock the read with io.EOF.
 	// This should cause DCEP layer (datachannel package) to call Close() which
@@ -445,19 +454,21 @@ func (s *Stream) onInboundStreamReset() {
 	// outgoing stream. When the peer sees that an incoming stream was
 	// reset, it also resets its corresponding outgoing stream. Once this
 	// is completed, the data channel is closed.
+	if state == StreamStateOpen {
+		s.log.Debugf("[%s] state change: open => closing", s.name)
+		s.SetState(StreamStateClosing)
+	}
 	s.readErr = io.EOF
 	s.readNotifier.Broadcast()
-
-	if s.state == StreamStateClosing {
-		s.log.Debugf("[%s] state change: closing => closed", s.name)
-		s.state = StreamStateClosed
-	}
 }
 
-// State return the stream state.
+// State atomically returns the stream state.
 func (s *Stream) State() StreamState {
-	s.lock.RLock()
-	defer s.lock.RUnlock()
-	return s.state
+	return StreamState(atomic.LoadInt32((*int32)(&s.state)))
+}
+
+// SetState atomically sets the stream state.
+func (s *Stream) SetState(newState StreamState) {
+	atomic.StoreInt32((*int32)(&s.state), int32(newState))
 }
 
diff --git a/vnet_test.go b/vnet_test.go
index 996f495e..bdf79946 100644
--- a/vnet_test.go
+++ b/vnet_test.go
@@ -1,7 +1,6 @@
 package sctp
 
 import (
-	"bytes"
 	"fmt"
 	"math/rand"
 	"net"
@@ -387,91 +386,84 @@ func TestRwndFull(t *testing.T) {
 }
 
 func TestStreamClose(t *testing.T) {
-	loopBackTest := func(t *testing.T, dropReconfigChunk bool) {
-		lim := test.TimeOut(time.Second * 10)
-		defer lim.Stop()
+	lim := test.TimeOut(time.Second * 10)
+	defer lim.Stop()
 
-		loggerFactory := logging.NewDefaultLoggerFactory()
-		log := loggerFactory.NewLogger("test")
+	loggerFactory := logging.NewDefaultLoggerFactory()
+	log := loggerFactory.NewLogger("test")
 
-		venv, err := buildVNetEnv(&vNetEnvConfig{
-			loggerFactory: loggerFactory,
-			log:           log,
-		})
+	venv, err := buildVNetEnv(&vNetEnvConfig{
+		loggerFactory: loggerFactory,
+		log:           log,
+	})
+	if !assert.NoError(t, err, "should succeed") {
+		return
+	}
+	if !assert.NotNil(t, venv, "should not be nil") {
+		return
+	}
+	defer venv.wan.Stop() // nolint:errcheck
+
+	serverStreamReady := make(chan struct{})
+	clientStreamReady := make(chan struct{})
+	clientStartClose := make(chan struct{})
+	serverStreamClosed := make(chan struct{})
+	shutDownClient := make(chan struct{})
+	clientShutDown := make(chan struct{})
+	serverShutDown := make(chan struct{})
+
+	go func() {
+		defer close(serverShutDown)
+		// connected UDP conn for server
+		conn, err := venv.net0.DialUDP("udp4",
+			&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000},
+			&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000},
+		)
 		if !assert.NoError(t, err, "should succeed") {
 			return
 		}
-		if !assert.NotNil(t, venv, "should not be nil") {
+		defer conn.Close() // nolint:errcheck
+
+		// server association
+		assoc, err := Server(Config{
+			NetConn:       conn,
+			LoggerFactory: loggerFactory,
+		})
+		if !assert.NoError(t, err, "should succeed") {
 			return
 		}
-		defer venv.wan.Stop() // nolint:errcheck
-
-		clientShutDown := make(chan struct{})
-		serverShutDown := make(chan struct{})
+		defer assoc.Close() // nolint:errcheck
 
-		const numMessages = 10
-		const messageSize = 1024
-		var messages [][]byte
-		var numServerReceived int
-		var numClientReceived int
+		log.Info("server handshake complete")
 
-		for i := 0; i < numMessages; i++ {
-			bytes := make([]byte, messageSize)
-			messages = append(messages, bytes)
+		stream, err := assoc.AcceptStream()
+		if !assert.NoError(t, err, "should succeed") {
+			return
 		}
+		defer stream.Close() // nolint:errcheck
 
-		go func() {
-			defer close(serverShutDown)
-			// connected UDP conn for server
-			conn, innerErr := venv.net0.DialUDP("udp4",
-				&net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000},
-				&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000},
-			)
-			if !assert.NoError(t, innerErr, "should succeed") {
-				return
-			}
-			defer conn.Close() // nolint:errcheck
-
-			// server association
-			assoc, innerErr := Server(Config{
-				NetConn:       conn,
-				LoggerFactory: loggerFactory,
-			})
-			if !assert.NoError(t, innerErr, "should succeed") {
-				return
+		buf := make([]byte, 1500)
+		for {
+			n, err := stream.Read(buf)
+			if err != nil {
+				t.Logf("server: Read returned %v", err)
+				break
 			}
-			defer assoc.Close() // nolint:errcheck
-
-			log.Info("server handshake complete")
-			stream, innerErr := assoc.AcceptStream()
-			if !assert.NoError(t, innerErr, "should succeed") {
-				return
+			if !assert.Equal(t, "HELLO", string(buf[:n]), "should receive HELLO") {
+				continue
 			}
-			assert.Equal(t, StreamStateOpen, stream.State())
 
-			buf := make([]byte, 1500)
-			for {
-				n, errRead := stream.Read(buf)
-				if errRead != nil {
-					log.Infof("server: Read returned %v", errRead)
-					_ = stream.Close() // nolint:errcheck
-					assert.Equal(t, StreamStateClosed, stream.State())
-					break
-				}
-
-				log.Infof("server: received %d bytes (%d)", n, numServerReceived)
-				assert.Equal(t, 0, bytes.Compare(buf[:n], messages[numServerReceived]), "should receive HELLO")
-
-				_, err2 := stream.Write(buf[:n])
-				assert.NoError(t, err2, "should succeed")
+			log.Info("server stream ready")
+			close(serverStreamReady)
+		}
 
-				numServerReceived++
-			}
-			// don't close association until the client's stream routine is complete
-			<-clientShutDown
-		}()
+		close(serverStreamClosed)
+		log.Info("server closing")
+	}()
 
+	go func() {
+		defer close(clientShutDown)
 		// connected UDP conn for client
 		conn, err := venv.net1.DialUDP("udp4",
 			&net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000},
@@ -480,7 +472,6 @@ func TestStreamClose(t *testing.T) {
 		if !assert.NoError(t, err, "should succeed") {
 			return
 		}
-		defer conn.Close() // nolint:errcheck
 
 		// client association
 		assoc, err := Client(Config{
@@ -498,51 +489,45 @@ func TestStreamClose(t *testing.T) {
 		if !assert.NoError(t, err, "should succeed") {
 			return
 		}
-		assert.Equal(t, StreamStateOpen, stream.State())
+		stream.SetReliabilityParams(false, ReliabilityTypeReliable, 0)
 
-		// begin client read-loop
+		// Send a message to let server side stream to open
+		_, err = stream.Write([]byte("HELLO"))
+		if !assert.NoError(t, err, "should succeed") {
+			return
+		}
+
+		buf := make([]byte, 1500)
+		done := make(chan struct{})
 		go func() {
-			defer close(clientShutDown)
 			for {
-				n, err2 := stream.Read(buf)
+				log.Info("client read")
+				_, err2 := stream.Read(buf)
 				if err2 != nil {
-					log.Infof("client: Read returned %v", err2)
-					assert.Equal(t, StreamStateClosed, stream.State())
+					t.Logf("client: Read returned %v", err2)
					break
				}
-
-				log.Infof("client: received %d bytes (%d)", n, numClientReceived)
-				assert.Equal(t, 0, bytes.Compare(buf[:n], messages[numClientReceived]), "should receive HELLO")
-				numClientReceived++
 			}
+			close(done)
 		}()
 
-		// Send messages to the server
-		for i := 0; i < numMessages; i++ {
-			_, err = stream.Write(messages[i])
-			assert.NoError(t, err, "should succeed")
-		}
+		log.Info("client stream ready")
+		close(clientStreamReady)
 
-		if dropReconfigChunk {
-			venv.dropNextReconfigChunk(1)
-		}
+		<-clientStartClose
+
+		// drop next 1 RECONFIG chunk
+		venv.dropNextReconfigChunk(1)
 
-		// Immediately close the stream
 		err = stream.Close()
 		assert.NoError(t, err, "should succeed")
-		assert.Equal(t, StreamStateClosing, stream.State())
 
 		log.Info("client wait for exit reading..")
-		<-clientShutDown
-
-		assert.Equal(t, numMessages, numServerReceived, "all messages should be received")
-		assert.Equal(t, numMessages, numClientReceived, "all messages should be received")
+		<-done
 
-		_, err = stream.Write([]byte{1})
+		<-shutDownClient
 
-		assert.Equal(t, err, errStreamClosed, "after closed should not allow write")
 		// Check if RECONFIG was actually dropped
 		assert.Equal(t, 0, venv.numToDropReconfig, "should be zero")
@@ -554,15 +539,26 @@ func TestStreamClose(t *testing.T) {
 		pendingReconfigs := len(assoc.reconfigs)
 		assoc.lock.RUnlock()
 		assert.Equal(t, 0, pendingReconfigs, "should be zero")
-	}
 
-	t.Run("without dropping Reconfig", func(t *testing.T) {
-		loopBackTest(t, false)
-	})
+		log.Info("client closing")
+	}()
 
-	t.Run("with dropping Reconfig", func(t *testing.T) {
-		loopBackTest(t, true)
-	})
+	// wait until both establish a stream
+	<-clientStreamReady
+	<-serverStreamReady
+
+	log.Info("stream ready")
+
+	// let client begin writing
+	log.Info("client start closing")
+	close(clientStartClose)
+
+	<-serverStreamClosed
+	close(shutDownClient)
+
+	<-clientShutDown
+	<-serverShutDown
+	log.Info("all done")
 }
 
 // this test case reproduces the issue mentioned in