diff --git a/session.go b/session.go index 5f6bdb4..580ab1d 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package smux import ( + "container/heap" "encoding/binary" "io" "net" @@ -22,6 +23,7 @@ var ( ) type writeRequest struct { + prio uint64 frame Frame result chan writeResult } @@ -73,6 +75,7 @@ type Session struct { deadline atomic.Value + shaper chan writeRequest // a shaper for writing writes chan writeRequest } @@ -85,6 +88,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { s.chAccepts = make(chan *Stream, defaultAcceptBacklog) s.bucket = int32(config.MaxReceiveBuffer) s.bucketNotify = make(chan struct{}, 1) + s.shaper = make(chan writeRequest) s.writes = make(chan writeRequest) s.chSocketReadError = make(chan struct{}) s.chSocketWriteError = make(chan struct{}) @@ -95,6 +99,8 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { } else { s.nextStreamID = 0 } + + go s.shaperLoop() go s.recvLoop() go s.sendLoop() go s.keepalive() @@ -357,7 +363,7 @@ func (s *Session) keepalive() { for { select { case <-tickerPing.C: - s.writeFrameInternal(newFrame(cmdNOP, 0), tickerPing.C) + s.writeFrameInternal(newFrame(cmdNOP, 0), tickerPing.C, 0) s.notifyBucket() // force a signal to the recvLoop case <-tickerTimeout.C: if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { @@ -370,6 +376,33 @@ func (s *Session) keepalive() { } } +// shaper shapes the sending sequence among streams +func (s *Session) shaperLoop() { + var reqs shaperHeap + var next writeRequest + var chWrite chan writeRequest + + for { + if len(reqs) > 0 { + chWrite = s.writes + next = heap.Pop(&reqs).(writeRequest) + } else { + chWrite = nil + } + + select { + case <-s.die: + return + case r := <-s.shaper: + if chWrite != nil { // next is valid, reshape + heap.Push(&reqs, next) + } + heap.Push(&reqs, r) + case chWrite <- next: + } + } +} + func (s *Session) sendLoop() { var buf []byte var n int @@ -428,17 +461,17 @@ func (s *Session) sendLoop() { // writeFrame writes the frame to the underlying connection // and returns the number of bytes written if successful func (s *Session) writeFrame(f Frame) (n int, err error) { - return s.writeFrameInternal(f, nil) + return s.writeFrameInternal(f, nil, 0) } // internal writeFrame version to support deadline used in keepalive -func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time) (int, error) { +func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) { req := writeRequest{ frame: f, result: make(chan writeResult, 1), } select { - case s.writes <- req: + case s.shaper <- req: case <-s.die: return 0, errors.WithStack(io.ErrClosedPipe) case <-s.chSocketWriteError: diff --git a/session_test.go b/session_test.go index dde6ac6..1812eae 100644 --- a/session_test.go +++ b/session_test.go @@ -640,7 +640,7 @@ func TestWriteFrameInternal(t *testing.T) { session.Close() for i := 0; i < 100; i++ { f := newFrame(byte(rand.Uint32()), rand.Uint32()) - session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout)) + session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0) } // random cmds @@ -652,14 +652,14 @@ func TestWriteFrameInternal(t *testing.T) { session, _ = Client(cli, nil) for i := 0; i < 100; i++ { f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32()) - session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout)) + session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0) } //deadline occur { c := make(chan time.Time) close(c) f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32()) - _, err := session.writeFrameInternal(f, c) + _, err := session.writeFrameInternal(f, c, 0) if !strings.Contains(err.Error(), "timeout") { t.Fatal("write frame with deadline failed", err) } @@ -684,7 +684,7 @@ func TestWriteFrameInternal(t *testing.T) { time.Sleep(time.Second) close(c) }() - _, err = session.writeFrameInternal(f, c) + _, err = session.writeFrameInternal(f, c, 0) if !strings.Contains(err.Error(), "closed pipe") { t.Fatal("write frame with to closed conn failed", err) } diff --git a/shaper.go b/shaper.go new file mode 100644 index 0000000..be03406 --- /dev/null +++ b/shaper.go @@ -0,0 +1,16 @@ +package smux + +type shaperHeap []writeRequest + +func (h shaperHeap) Len() int { return len(h) } +func (h shaperHeap) Less(i, j int) bool { return h[i].prio < h[j].prio } +func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) } + +func (h *shaperHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/stream.go b/stream.go index 8d7bbfa..aa19f38 100644 --- a/stream.go +++ b/stream.go @@ -35,6 +35,9 @@ type Stream struct { // deadlines readDeadline atomic.Value writeDeadline atomic.Value + + // count writes + numWrite uint64 } // newStream initiates a Stream struct @@ -132,7 +135,8 @@ func (s *Stream) Write(b []byte) (n int, err error) { } frame.data = bts[:sz] bts = bts[sz:] - n, err := s.sess.writeFrameInternal(frame, deadline) + n, err := s.sess.writeFrameInternal(frame, deadline, s.numWrite) + s.numWrite++ sent += n if err != nil { return sent, errors.WithStack(err)