Skip to content

Commit

Permalink
add a shaper for fair transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
xtaci committed Sep 10, 2019
1 parent 6aa95ef commit 78fdaa9
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 9 deletions.
41 changes: 37 additions & 4 deletions session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package smux

import (
"container/heap"
"encoding/binary"
"io"
"net"
Expand All @@ -22,6 +23,7 @@ var (
)

type writeRequest struct {
prio uint64
frame Frame
result chan writeResult
}
Expand Down Expand Up @@ -73,6 +75,7 @@ type Session struct {

deadline atomic.Value

shaper chan writeRequest // a shaper for writing
writes chan writeRequest
}

Expand All @@ -85,6 +88,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
s.bucket = int32(config.MaxReceiveBuffer)
s.bucketNotify = make(chan struct{}, 1)
s.shaper = make(chan writeRequest)
s.writes = make(chan writeRequest)
s.chSocketReadError = make(chan struct{})
s.chSocketWriteError = make(chan struct{})
Expand All @@ -95,6 +99,8 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
} else {
s.nextStreamID = 0
}

go s.shaperLoop()
go s.recvLoop()
go s.sendLoop()
go s.keepalive()
Expand Down Expand Up @@ -357,7 +363,7 @@ func (s *Session) keepalive() {
for {
select {
case <-tickerPing.C:
s.writeFrameInternal(newFrame(cmdNOP, 0), tickerPing.C)
s.writeFrameInternal(newFrame(cmdNOP, 0), tickerPing.C, 0)
s.notifyBucket() // force a signal to the recvLoop
case <-tickerTimeout.C:
if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
Expand All @@ -370,6 +376,33 @@ func (s *Session) keepalive() {
}
}

// shaper shapes the sending sequence among streams
func (s *Session) shaperLoop() {
var reqs shaperHeap
var next writeRequest
var chWrite chan writeRequest

for {
if len(reqs) > 0 {
chWrite = s.writes
next = heap.Pop(&reqs).(writeRequest)
} else {
chWrite = nil
}

select {
case <-s.die:
return
case r := <-s.shaper:
if chWrite != nil { // next is valid, reshape
heap.Push(&reqs, next)
}
heap.Push(&reqs, r)
case chWrite <- next:
}
}
}

func (s *Session) sendLoop() {
var buf []byte
var n int
Expand Down Expand Up @@ -428,17 +461,17 @@ func (s *Session) sendLoop() {
// writeFrame writes the frame to the underlying connection
// and returns the number of bytes written if successful
func (s *Session) writeFrame(f Frame) (n int, err error) {
return s.writeFrameInternal(f, nil)
return s.writeFrameInternal(f, nil, 0)
}

// internal writeFrame version to support deadline used in keepalive
func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time) (int, error) {
func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) {
req := writeRequest{
frame: f,
result: make(chan writeResult, 1),
}
select {
case s.writes <- req:
case s.shaper <- req:
case <-s.die:
return 0, errors.WithStack(io.ErrClosedPipe)
case <-s.chSocketWriteError:
Expand Down
8 changes: 4 additions & 4 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ func TestWriteFrameInternal(t *testing.T) {
session.Close()
for i := 0; i < 100; i++ {
f := newFrame(byte(rand.Uint32()), rand.Uint32())
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout))
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0)
}

// random cmds
Expand All @@ -652,14 +652,14 @@ func TestWriteFrameInternal(t *testing.T) {
session, _ = Client(cli, nil)
for i := 0; i < 100; i++ {
f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32())
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout))
session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0)
}
//deadline occur
{
c := make(chan time.Time)
close(c)
f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32())
_, err := session.writeFrameInternal(f, c)
_, err := session.writeFrameInternal(f, c, 0)
if !strings.Contains(err.Error(), "timeout") {
t.Fatal("write frame with deadline failed", err)
}
Expand All @@ -684,7 +684,7 @@ func TestWriteFrameInternal(t *testing.T) {
time.Sleep(time.Second)
close(c)
}()
_, err = session.writeFrameInternal(f, c)
_, err = session.writeFrameInternal(f, c, 0)
if !strings.Contains(err.Error(), "closed pipe") {
t.Fatal("write frame with to closed conn failed", err)
}
Expand Down
16 changes: 16 additions & 0 deletions shaper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package smux

type shaperHeap []writeRequest

func (h shaperHeap) Len() int { return len(h) }
func (h shaperHeap) Less(i, j int) bool { return h[i].prio < h[j].prio }
func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) }

func (h *shaperHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
6 changes: 5 additions & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ type Stream struct {
// deadlines
readDeadline atomic.Value
writeDeadline atomic.Value

// count writes
numWrite uint64
}

// newStream initiates a Stream struct
Expand Down Expand Up @@ -132,7 +135,8 @@ func (s *Stream) Write(b []byte) (n int, err error) {
}
frame.data = bts[:sz]
bts = bts[sz:]
n, err := s.sess.writeFrameInternal(frame, deadline)
n, err := s.sess.writeFrameInternal(frame, deadline, s.numWrite)
s.numWrite++
sent += n
if err != nil {
return sent, errors.WithStack(err)
Expand Down

0 comments on commit 78fdaa9

Please sign in to comment.