diff --git a/.travis.yml b/.travis.yml index e1d30fa..a90ca48 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: - - 1.9.x - 1.10.x - 1.11.x + - 1.12.x before_install: - go get -t -v ./... @@ -11,7 +11,7 @@ install: - go get github.com/xtaci/smux script: - - go test -coverprofile=coverage.txt -covermode=atomic -bench . + - go test -coverprofile=coverage.txt -covermode=atomic -bench . -v after_success: - bash <(curl -s https://codecov.io/bash) diff --git a/frame.go b/frame.go index 71d3d44..eac341d 100644 --- a/frame.go +++ b/frame.go @@ -14,6 +14,9 @@ const ( // cmds cmdFIN // stream close, a.k.a EOF mark cmdPSH // data push cmdNOP // no operation + cmdACK // for test RTT + cmdFUL // buffer full + cmdEMP // buffer empty ) const ( diff --git a/mux.go b/mux.go index 3cc8f11..7127c0d 100644 --- a/mux.go +++ b/mux.go @@ -1,7 +1,6 @@ package smux import ( - "fmt" "io" "time" @@ -24,15 +23,27 @@ type Config struct { // MaxReceiveBuffer is used to control the maximum // number of data in the buffer pool MaxReceiveBuffer int + + // Enable Stream buffer + EnableStreamBuffer bool + + // maximum bytes that each Stream can use + MaxStreamBuffer int + + // for initial boost + BoostTimeout time.Duration } // DefaultConfig is used to return a default configuration func DefaultConfig() *Config { return &Config{ - KeepAliveInterval: 10 * time.Second, - KeepAliveTimeout: 30 * time.Second, - MaxFrameSize: 32768, - MaxReceiveBuffer: 4194304, + KeepAliveInterval: 2500 * time.Millisecond, + KeepAliveTimeout: 7500 * time.Millisecond, // RTT usually < 7500ms + MaxFrameSize: 32768, + MaxReceiveBuffer: 4 * 1024 * 1024, + EnableStreamBuffer: true, + MaxStreamBuffer: 200 * 8 * 1024, + BoostTimeout: 10 * time.Second, } } @@ -41,9 +52,6 @@ func VerifyConfig(config *Config) error { if config.KeepAliveInterval == 0 { return errors.New("keep-alive interval must be positive") } - if config.KeepAliveTimeout < config.KeepAliveInterval { - return fmt.Errorf("keep-alive timeout must be larger than keep-alive interval") - } if config.MaxFrameSize <= 0 { return errors.New("max frame size must be positive") } @@ -53,6 +61,9 @@ func VerifyConfig(config *Config) error { if config.MaxReceiveBuffer <= 0 { return errors.New("max receive buffer must be positive") } + if config.MaxStreamBuffer <= 0 { + return errors.New("max stream receive buffer must be positive") + } return nil } diff --git a/mux_test.go b/mux_test.go index 638e67c..726ab30 100644 --- a/mux_test.go +++ b/mux_test.go @@ -26,8 +26,7 @@ func TestConfig(t *testing.T) { } config = DefaultConfig() - config.KeepAliveInterval = 10 - config.KeepAliveTimeout = 5 + config.MaxFrameSize = 0 err = VerifyConfig(config) t.Log(err) if err == nil { @@ -35,7 +34,7 @@ func TestConfig(t *testing.T) { } config = DefaultConfig() - config.MaxFrameSize = 0 + config.MaxFrameSize = 65536 err = VerifyConfig(config) t.Log(err) if err == nil { @@ -43,7 +42,7 @@ func TestConfig(t *testing.T) { } config = DefaultConfig() - config.MaxFrameSize = 65536 + config.MaxReceiveBuffer = 0 err = VerifyConfig(config) t.Log(err) if err == nil { @@ -51,7 +50,7 @@ func TestConfig(t *testing.T) { } config = DefaultConfig() - config.MaxReceiveBuffer = 0 + config.MaxStreamBuffer = 0 err = VerifyConfig(config) t.Log(err) if err == nil { diff --git a/session.go b/session.go index d0c3a13..1311b3d 100644 --- a/session.go +++ b/session.go @@ -14,10 +14,10 @@ const ( defaultAcceptBacklog = 1024 ) -const ( - errBrokenPipe = "broken pipe" - errInvalidProtocol = "invalid protocol version" - 
errGoAway = "stream id overflows, should start a new connection" +var ( + ErrBrokenPipe = errors.New("broken pipe") + ErrInvalidProtocol = errors.New("invalid protocol version") + ErrGoAway = errors.New("stream id overflows, should start a new connection") ) type writeRequest struct { @@ -55,6 +55,12 @@ type Session struct { deadline atomic.Value writes chan writeRequest + writeCtrl chan writeRequest + + rttSn uint32 + rttTest atomic.Value // time.Time + rtt atomic.Value // time.Duration + gotACK chan struct{} } func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { @@ -67,6 +73,10 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { s.bucket = int32(config.MaxReceiveBuffer) s.bucketNotify = make(chan struct{}, 1) s.writes = make(chan writeRequest) + s.writeCtrl = make(chan writeRequest, 4) + + s.rtt.Store(500 * time.Millisecond) + s.gotACK = make(chan struct{}, 1) if client { s.nextStreamID = 1 @@ -82,14 +92,14 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { // OpenStream is used to create a new stream func (s *Session) OpenStream() (*Stream, error) { if s.IsClosed() { - return nil, errors.New(errBrokenPipe) + return nil, ErrBrokenPipe } // generate stream id s.nextStreamIDLock.Lock() if s.goAway > 0 { s.nextStreamIDLock.Unlock() - return nil, errors.New(errGoAway) + return nil, ErrGoAway } s.nextStreamID += 2 @@ -97,7 +107,7 @@ func (s *Session) OpenStream() (*Stream, error) { if sid == sid%2 { // stream-id overflows s.goAway = 1 s.nextStreamIDLock.Unlock() - return nil, errors.New(errGoAway) + return nil, ErrGoAway } s.nextStreamIDLock.Unlock() @@ -128,7 +138,7 @@ func (s *Session) AcceptStream() (*Stream, error) { case <-deadline: return nil, errTimeout case <-s.die: - return nil, errors.New(errBrokenPipe) + return nil, ErrBrokenPipe } } @@ -139,7 +149,7 @@ func (s *Session) Close() (err error) { select { case <-s.die: s.dieLock.Unlock() - return errors.New(errBrokenPipe) + return ErrBrokenPipe default: close(s.die) s.dieLock.Unlock() @@ -189,15 +199,14 @@ func (s *Session) SetDeadline(t time.Time) error { } // notify the session that a stream has closed -func (s *Session) streamClosed(sid uint32) { +func (s *Session) streamClosed(stream *Stream) { s.streamLock.Lock() - if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket - if atomic.AddInt32(&s.bucket, int32(n)) > 0 { - s.notifyBucket() - } - } - delete(s.streams, sid) + delete(s.streams, stream.id) s.streamLock.Unlock() + + if n := stream.recycleTokens(); n > 0 { // return remaining tokens to the bucket + s.returnTokens(n) + } } // returnTokens is called by stream to return token after read @@ -216,7 +225,7 @@ func (s *Session) readFrame(buffer []byte) (f Frame, err error) { } if hdr.Version() != version { - return f, errors.New(errInvalidProtocol) + return f, ErrInvalidProtocol } f.ver = hdr.Version() @@ -244,6 +253,9 @@ func (s *Session) recvLoop() { switch f.cmd { case cmdNOP: + if s.config.EnableStreamBuffer { + s.writeFrameCtrl(newFrame(cmdACK, f.sid), time.After(s.config.KeepAliveTimeout)) + } case cmdSYN: s.streamLock.Lock() if _, ok := s.streams[f.sid]; !ok { @@ -265,14 +277,33 @@ func (s *Session) recvLoop() { case cmdPSH: s.streamLock.Lock() if stream, ok := s.streams[f.sid]; ok { + s.streamLock.Unlock() atomic.AddInt32(&s.bucket, -int32(len(f.data))) stream.pushBytes(f.data) stream.notifyReadEvent() + } else { + s.streamLock.Unlock() + } + case cmdFUL: + s.streamLock.Lock() + if stream, ok := 
s.streams[f.sid]; ok { + stream.pauseWrite() } s.streamLock.Unlock() + case cmdEMP: + s.streamLock.Lock() + if stream, ok := s.streams[f.sid]; ok { + stream.resumeWrite() + } + s.streamLock.Unlock() + case cmdACK: + if f.sid == atomic.LoadUint32(&s.rttSn) { + rttTest := s.rttTest.Load().(time.Time) + s.rtt.Store(time.Now().Sub(rttTest)) + s.gotACK <- struct{}{} + } default: - s.Close() - return + // nop, for random noise or new feature cmd ID } } else { s.Close() @@ -282,53 +313,107 @@ func (s *Session) recvLoop() { } func (s *Session) keepalive() { - tickerPing := time.NewTicker(s.config.KeepAliveInterval) - tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout) - defer tickerPing.Stop() - defer tickerTimeout.Stop() - for { + + timeout := s.config.KeepAliveTimeout + if !s.config.EnableStreamBuffer && s.config.KeepAliveInterval < s.config.KeepAliveTimeout { + timeout = s.config.KeepAliveInterval + } + + var ping = func(gotACK <-chan struct{}) bool { + ckTimeout := time.NewTimer(timeout) // setup timeout check + s.rttTest.Store(time.Now()) + err := s.writeFrameCtrl(newFrame(cmdNOP, atomic.AddUint32(&s.rttSn, uint32(1))), ckTimeout.C) + if err != nil { // fail to send + s.Close() + return false + } + select { - case <-tickerPing.C: - s.writeFrameInternal(newFrame(cmdNOP, 0), tickerPing.C) - s.notifyBucket() // force a signal to the recvLoop - case <-tickerTimeout.C: + case <-ckTimeout.C: // should never trigger if no timeout if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { - s.Close() + s.Close() // timeout & not any frame recv + return false + } + case <-s.die: + return false + case <-gotACK: // got ACK + } + return true + } + + if !s.config.EnableStreamBuffer { + for { + if !ping(nil) { + return + } + s.notifyBucket() // force a signal to the recvLoop + } + return + } + + if !ping(s.gotACK) { + return + } + t := time.NewTimer(s.config.KeepAliveInterval) + + for { + select { + case <-t.C: // send ping + //t.Stop() + if !ping(s.gotACK) { return } + s.notifyBucket() // force a signal to the recvLoop + t.Reset(s.config.KeepAliveInterval) // setup next ping + case <-s.die: return } } } +func (s *Session) GetRTT() (time.Duration) { + return s.rtt.Load().(time.Duration) +} + func (s *Session) sendLoop() { buf := make([]byte, (1<<16)+headerSize) + send := func(request writeRequest) { + buf[0] = request.frame.ver + buf[1] = request.frame.cmd + binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) + binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) + copy(buf[headerSize:], request.frame.data) + //s.conn.SetWriteDeadline(time.Now().Add(s.config.KeepAliveTimeout)) + n, err := s.conn.Write(buf[:headerSize+len(request.frame.data)]) + + n -= headerSize + if n < 0 { + n = 0 + } + + result := writeResult{ + n: n, + err: err, + } + + request.result <- result + close(request.result) + } + + var req writeRequest for { select { case <-s.die: return - case request := <-s.writes: - buf[0] = request.frame.ver - buf[1] = request.frame.cmd - binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) - binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) - copy(buf[headerSize:], request.frame.data) - n, err := s.conn.Write(buf[:headerSize+len(request.frame.data)]) - - n -= headerSize - if n < 0 { - n = 0 - } - - result := writeResult{ - n: n, - err: err, + case req = <-s.writeCtrl: + case req = <-s.writes: + for len(s.writeCtrl) > 0 { + reqCtrl := <-s.writeCtrl + send(reqCtrl) } - - request.result <- result - close(request.result) } + send(req) } } @@ -340,24 
+425,60 @@ func (s *Session) writeFrame(f Frame) (n int, err error) { // internal writeFrame version to support deadline used in keepalive func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time) (int, error) { + req, err := s.writeFrameHalf(f, deadline) + if err != nil { + return 0, err + } + + select { + case result := <-req.result: + return result.n, result.err + case <-deadline: + return 0, errTimeout + case <-s.die: + return 0, ErrBrokenPipe + } +} + + +// must send but nerver block +func (s *Session) writeFrameCtrl(f Frame, deadline <-chan time.Time) error { req := writeRequest{ frame: f, result: make(chan writeResult, 1), } select { case <-s.die: - return 0, errors.New(errBrokenPipe) - case s.writes <- req: + return ErrBrokenPipe + case s.writeCtrl <- req: case <-deadline: - return 0, errTimeout + return errTimeout } + return nil +} +func (s *Session) writeFrameHalf(f Frame, deadline <-chan time.Time) (*writeRequest, error) { + req := writeRequest{ + frame: f, + result: make(chan writeResult, 1), + } select { - case result := <-req.result: - return result.n, result.err - case <-deadline: - return 0, errTimeout case <-s.die: - return 0, errors.New(errBrokenPipe) + return nil, ErrBrokenPipe + case s.writes <- req: + case <-deadline: + return nil, errTimeout + } + return &req, nil +} + +func (s *Session) WriteCustomCMD(cmd byte, bts []byte) (n int, err error) { + if s.IsClosed() { + return 0, ErrBrokenPipe } + f := newFrame(cmd, 0) + f.data = bts + + return s.writeFrame(f) } + diff --git a/session_test.go b/session_test.go index 32fd20b..5e0a65e 100644 --- a/session_test.go +++ b/session_test.go @@ -7,12 +7,24 @@ import ( "io" "math/rand" "net" + "net/http" + _ "net/http/pprof" + //"runtime" + "runtime/pprof" + "bytes" "strings" "sync" + "sync/atomic" "testing" "time" ) +func init() { + go func() { + http.ListenAndServe("localhost:6060", nil) + }() +} + // setupServer starts new server listening on a random localhost port and // returns address of the server, function to stop the server, new client // connection to this server or an error. 
@@ -29,6 +41,7 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, } go handleConnection(conn) }() + time.Sleep(20 * time.Millisecond) addr = ln.Addr().String() conn, err := net.Dial("tcp", addr) if err != nil { @@ -38,6 +51,12 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, return ln.Addr().String(), func() { ln.Close() }, conn, nil } +func setupServerPipe(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) { + ln, conn := net.Pipe() + go handleConnection(ln) + return "", func() { ln.Close() }, conn, nil +} + func handleConnection(conn net.Conn) { session, _ := Server(conn, nil) for { @@ -87,10 +106,19 @@ func TestEcho(t *testing.T) { } func TestGetDieCh(t *testing.T) { - cs, ss, err := getSmuxStreamPair() + cs, ss, err := getSmuxStreamPair(nil) if err != nil { t.Fatal(err) } + testGetDieCh(t, cs, ss) +} + +func TestGetDieCh2(t *testing.T) { + cs, ss, _ := getSmuxStreamPair(nil) + testGetDieCh(t, cs, ss) +} + +func testGetDieCh(t *testing.T, cs, ss *Stream) { defer ss.Close() dieCh := ss.GetDieCh() go func() { @@ -109,6 +137,17 @@ func TestSpeed(t *testing.T) { t.Fatal(err) } defer stop() + testSpeed(t, cli) +} + +func TestSpeed2(t *testing.T) { + _, stop, cli, _ := setupServerPipe(t) + defer stop() + testSpeed(t, cli) +} + +func testSpeed(t *testing.T, cli net.Conn) { + defer cli.Close() session, _ := Client(cli, nil) stream, _ := session.OpenStream() t.Log(stream.LocalAddr(), stream.RemoteAddr()) @@ -149,29 +188,60 @@ func TestParallel(t *testing.T) { t.Fatal(err) } defer stop() + testParallel(t, cli) +} + +func TestParallel2(t *testing.T) { + _, stop, cli, _ := setupServerPipe(t) + defer stop() + testParallel(t, cli) +} + +func testParallel(t *testing.T, cli net.Conn) { + defer cli.Close() session, _ := Client(cli, nil) par := 1000 messages := 100 + die := make(chan struct{}) var wg sync.WaitGroup - wg.Add(par) for i := 0; i < par; i++ { - stream, _ := session.OpenStream() + stream, err := session.OpenStream() + if err != nil { + dumpGoroutine(t) + t.Fatalf("cannot create stream %v: %v", i, err) + break + } + wg.Add(1) go func(s *Stream) { buf := make([]byte, 20) - for j := 0; j < messages; j++ { - msg := fmt.Sprintf("hello%v", j) - s.Write([]byte(msg)) - if _, err := s.Read(buf); err != nil { - break + + for { // keep read & write untill all stream end + select { + case <-die: + goto END + default: + } + for j := 0; j < messages; j++ { + msg := fmt.Sprintf("hello%v", j) + s.Write([]byte(msg)) + if _, err := s.Read(buf); err != nil { + break + } } } + END: + //<-die s.Close() wg.Done() }(stream) } - t.Log("created", session.NumStreams(), "streams") + t.Log("created", session.NumStreams(), "streams and keep streams do read & write") + time.Sleep(500 * time.Millisecond) + t.Log("kill all", session.NumStreams(), "streams") + close(die) wg.Wait() + t.Log("all", session.NumStreams(), "streams end") session.Close() } @@ -210,20 +280,19 @@ func TestConcurrentClose(t *testing.T) { } defer stop() session, _ := Client(cli, nil) - numStreams := 100 + numStreams := 1000 streams := make([]*Stream, 0, numStreams) - var wg sync.WaitGroup - wg.Add(numStreams) - for i := 0; i < 100; i++ { + for i := 0; i < numStreams; i++ { stream, _ := session.OpenStream() streams = append(streams, stream) } + var wg sync.WaitGroup for _, s := range streams { - stream := s - go func() { + wg.Add(1) + go func(stream *Stream) { stream.Close() wg.Done() - }() + }(s) } session.Close() wg.Wait() @@ -235,6 +304,25 @@ func 
TestTinyReadBuffer(t *testing.T) { t.Fatal(err) } defer stop() + testTinyReadBuffer(t, cli) +} + +func TestTinyReadBuffer2(t *testing.T) { + _, stop, cli, _ := setupServerPipe(t) + defer stop() + testTinyReadBuffer(t, cli) +} + +func TestTinyReadBuffer3(t *testing.T) { + srv, cli := net.Pipe() + defer srv.Close() + go handleConnection(srv) + testTinyReadBuffer(t, cli) +} + +func testTinyReadBuffer(t *testing.T, cli net.Conn) { + defer cli.Close() + session, _ := Client(cli, nil) stream, _ := session.OpenStream() const N = 100 @@ -287,6 +375,7 @@ func TestKeepAliveTimeout(t *testing.T) { go func() { ln.Accept() }() + defer ln.Close() cli, err := net.Dial("tcp", ln.Addr().String()) if err != nil { @@ -294,6 +383,27 @@ func TestKeepAliveTimeout(t *testing.T) { } defer cli.Close() + testKeepAliveTimeout(t, cli) +} + +func TestKeepAliveTimeout2(t *testing.T) { + c1, c2, err := getTCPConnectionPair() + if err != nil { + t.Fatal(err) + } + defer c1.Close() + defer c2.Close() + testKeepAliveTimeout(t, c1) +} + +func TestKeepAliveTimeout3(t *testing.T) { + c1, c2 := net.Pipe() + defer c1.Close() + defer c2.Close() + testKeepAliveTimeout(t, c1) +} + +func testKeepAliveTimeout(t *testing.T, cli net.Conn) { config := DefaultConfig() config.KeepAliveInterval = time.Second config.KeepAliveTimeout = 2 * time.Second @@ -304,13 +414,13 @@ func TestKeepAliveTimeout(t *testing.T) { } } -type blockWriteConn struct { +type delayWriteConn struct { net.Conn + Delay time.Duration } -func (c *blockWriteConn) Write(b []byte) (n int, err error) { - forever := time.Hour * 24 - time.Sleep(forever) +func (c *delayWriteConn) Write(b []byte) (n int, err error) { + time.Sleep(c.Delay) return c.Conn.Write(b) } @@ -329,8 +439,29 @@ func TestKeepAliveBlockWriteTimeout(t *testing.T) { t.Fatal(err) } defer cli.Close() + testKeepAliveBlockWriteTimeout(t, cli) +} + +func TestKeepAliveBlockWriteTimeout2(t *testing.T) { + c1, c2, err := getTCPConnectionPair() + if err != nil { + t.Fatal(err) + } + defer c1.Close() + defer c2.Close() + testKeepAliveBlockWriteTimeout(t, c1) +} + +func TestKeepAliveBlockWriteTimeout3(t *testing.T) { + c1, c2 := net.Pipe() + defer c1.Close() + defer c2.Close() + testKeepAliveBlockWriteTimeout(t, c1) +} + +func testKeepAliveBlockWriteTimeout(t *testing.T, cli net.Conn) { //when writeFrame block, keepalive in old version never timeout - blockWriteCli := &blockWriteConn{cli} + blockWriteCli := &delayWriteConn{cli, 24 * time.Hour} config := DefaultConfig() config.KeepAliveInterval = time.Second @@ -342,6 +473,43 @@ func TestKeepAliveBlockWriteTimeout(t *testing.T) { } } +func TestKeepAliveDelayWriteTimeout(t *testing.T) { + c1, c2, err := getTCPConnectionPair() + if err != nil { + t.Fatal(err) + } + testKeepAliveDelayWriteTimeout(t, c1, c2) +} + +func TestKeepAliveDelayWriteTimeoutPipe(t *testing.T) { + c1, c2 := net.Pipe() + testKeepAliveDelayWriteTimeout(t, c1, c2) +} + +func testKeepAliveDelayWriteTimeout(t *testing.T, c1 net.Conn, c2 net.Conn) { + defer c1.Close() + defer c2.Close() + + configSrv := DefaultConfig() + configSrv.KeepAliveInterval = 23 * time.Hour // never send ping + configSrv.KeepAliveTimeout = 24 * time.Hour // never check + srv, _ := Server(c2, configSrv) + defer srv.Close() + + // delay 200 ms, old KeepAlive will timeout + delayWriteCli := &delayWriteConn{c1, 200 * time.Millisecond} + //delayWriteCli := &delayWriteConn{c1, 24 * time.Hour} + + config := DefaultConfig() + config.KeepAliveInterval = 200 * time.Millisecond // send @ 200ms + config.KeepAliveTimeout = 300 * 
time.Millisecond // should check after 300 ms (= 500 ms), not @ 300 ms + session, _ := Client(delayWriteCli, config) + time.Sleep(2 * time.Second) + if session.IsClosed() { + t.Fatal("keepalive-timeout failed, close too quickly") + } +} + func TestServerEcho(t *testing.T) { ln, err := net.Listen("tcp", "localhost:0") if err != nil { @@ -407,6 +575,50 @@ func TestServerEcho(t *testing.T) { } } +func TestServerEcho2(t *testing.T) { + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() + + go func() { + session, _ := Server(srv, nil) + if stream, err := session.OpenStream(); err == nil { + const N = 100 + buf := make([]byte, 10) + for i := 0; i < N; i++ { + msg := fmt.Sprintf("hello%v", i) + stream.Write([]byte(msg)) + if n, err := stream.Read(buf); err != nil { + t.Fatal(err) + } else if string(buf[:n]) != msg { + t.Fatal(err) + } + } + stream.Close() + } else { + t.Fatal(err) + } + }() + + if session, err := Client(cli, nil); err == nil { + if stream, err := session.AcceptStream(); err == nil { + buf := make([]byte, 65536) + for { + n, err := stream.Read(buf) + if err != nil { + break + } + stream.Write(buf[:n]) + } + } else { + t.Fatal(err) + } + } else { + t.Fatal(err) + } +} + + func TestSendWithoutRecv(t *testing.T) { _, stop, cli, err := setupServer(t) if err != nil { @@ -499,6 +711,7 @@ func TestRandomFrame(t *testing.T) { t.Fatal(err) } defer stop() + // pure random session, _ := Client(cli, nil) for i := 0; i < 100; i++ { @@ -525,7 +738,7 @@ func TestRandomFrame(t *testing.T) { if err != nil { t.Fatal(err) } - allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP} + allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP, cmdACK, cmdFUL, cmdEMP} session, _ = Client(cli, nil) for i := 0; i < 100; i++ { f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32()) @@ -600,6 +813,7 @@ func TestWriteFrameInternal(t *testing.T) { t.Fatal(err) } defer stop() + // pure random session, _ := Client(cli, nil) for i := 0; i < 100; i++ { @@ -653,7 +867,7 @@ func TestWriteFrameInternal(t *testing.T) { config := DefaultConfig() config.KeepAliveInterval = time.Second config.KeepAliveTimeout = 2 * time.Second - session, _ = Client(&blockWriteConn{cli}, config) + session, _ = Client(&delayWriteConn{cli, 24 * time.Hour}, config) f := newFrame(byte(rand.Uint32()), rand.Uint32()) c := make(chan time.Time) go func() { @@ -664,7 +878,7 @@ func TestWriteFrameInternal(t *testing.T) { close(c) }() _, err = session.writeFrameInternal(f, c) - if err.Error() != errBrokenPipe { + if err.Error() != ErrBrokenPipe.Error() { t.Fatal("write frame with deadline failed", err) } } @@ -676,6 +890,16 @@ func TestReadDeadline(t *testing.T) { t.Fatal(err) } defer stop() + testReadDeadline(t, cli) +} + +func TestReadDeadline2(t *testing.T) { + _, stop, cli, _ := setupServerPipe(t) + defer stop() + testReadDeadline(t, cli) +} + +func testReadDeadline(t *testing.T, cli net.Conn) { session, _ := Client(cli, nil) stream, _ := session.OpenStream() const N = 100 @@ -703,6 +927,16 @@ func TestWriteDeadline(t *testing.T) { t.Fatal(err) } defer stop() + testWriteDeadline(t, cli) +} + +func TestWriteDeadline2(t *testing.T) { + _, stop, cli, _ := setupServerPipe(t) + defer stop() + testWriteDeadline(t, cli) +} + +func testWriteDeadline(t *testing.T, cli net.Conn) { session, _ := Client(cli, nil) stream, _ := session.OpenStream() buf := make([]byte, 10) @@ -719,12 +953,555 @@ func TestWriteDeadline(t *testing.T) { session.Close() } +func TestSlowReadBlocking(t *testing.T) { + c1, c2, err := getTCPConnectionPair() + if err != nil { + 
t.Fatal(err) + return + } + + testSlowReadBlocking(t, c1, c2) +} + +func TestSlowReadBlocking2(t *testing.T) { + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() + + testSlowReadBlocking(t, srv, cli) +} + +func testSlowReadBlocking(t *testing.T, srv net.Conn, cli net.Conn) { + config := &Config{ + KeepAliveInterval: 100 * time.Millisecond, + KeepAliveTimeout: 500 * time.Millisecond, + MaxFrameSize: 4096, + MaxReceiveBuffer: 1 * 1024 * 1024, + EnableStreamBuffer: true, + MaxStreamBuffer: 8192, + BoostTimeout: 0 * time.Millisecond, + } + + go func (conn net.Conn) { + session, _ := Server(conn, config) + for { + if stream, err := session.AcceptStream(); err == nil { + go func(s io.ReadWriteCloser) { + defer s.Close() + buf := make([]byte, 1024 * 1024, 1024 * 1024) + for { + n, err := s.Read(buf) + if err != nil { + return + } + //t.Log("s1", stream.id, "session.bucket", atomic.LoadInt32(&session.bucket), "stream.bucket", atomic.LoadInt32(&stream.bucket), n) + s.Write(buf[:n]) + } + }(stream) + } else { + return + } + } + }(srv) + + session, _ := Client(cli, config) + startNotify := make(chan bool, 1) + flag := int32(1) + var wg sync.WaitGroup + + wg.Add(1) + go func() { // fast write & slow read + defer wg.Done() + + stream, err := session.OpenStream() + if err == nil { + t.Log("fast write stream start...") + defer func() { + stream.Close() + t.Log("fast write stream end...") + }() + + const SIZE = 1 * 1024 // Bytes + const SPDW = 16 * 1024 * 1024 // Bytes/s + const SPDR = 512 * 1024 // Bytes/s + const TestDtW = time.Second / time.Duration(SPDW/SIZE) + const TestDtR = time.Second / time.Duration(SPDR/SIZE) + + var fwg sync.WaitGroup + fwg.Add(1) + go func() { // read = SPDR + defer fwg.Done() + rbuf := make([]byte, SIZE, SIZE) + for atomic.LoadInt32(&flag) > 0 { + stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if _, err := stream.Read(rbuf); err != nil { + if strings.Contains(err.Error(), "i/o timeout") { + //t.Logf("read block too long: %v", err) + continue + } + break + } + time.Sleep(TestDtR) // slow down read + } + }() + + buf := make([]byte, SIZE, SIZE) + for i := range buf { + buf[i] = byte('-') + } + startNotify <- true + for atomic.LoadInt32(&flag) > 0 { // write = SPDW + stream.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + _, err := stream.Write(buf) + if err != nil { + if strings.Contains(err.Error(), "i/o timeout") { + //t.Logf("write block too long: %v", err) + continue + } + break + } + //t.Log("f2", stream.id, "session.bucket", atomic.LoadInt32(&session.bucket), "stream.bucket", atomic.LoadInt32(&stream.bucket)) + time.Sleep(TestDtW) // slow down write + } + fwg.Wait() + + } else { + t.Fatal(err) + } + }() + + wg.Add(1) + go func() { // normal write, rtt test + defer func() { + session.Close() + wg.Done() + }() + + stream, err := session.OpenStream() + if err == nil { + t.Log("normal stream start...") + defer func() { + atomic.StoreInt32(&flag, int32(0)) + stream.Close() + t.Log("normal stream end...") + }() + + const N = 100 + const TestDt = 50 * time.Millisecond + const TestTimeout = 500 * time.Millisecond + + buf := make([]byte, 12) + <- startNotify + for i := 0; i < N; i++ { + msg := fmt.Sprintf("hello%v", i) + start := time.Now() + + stream.SetWriteDeadline(time.Now().Add(TestTimeout)) + _, err := stream.Write([]byte(msg)) + if err != nil && strings.Contains(err.Error(), "i/o timeout") { + t.Log(stream.id, i, err, + "session.bucket", atomic.LoadInt32(&session.bucket), + "stream.bucket", atomic.LoadInt32(&stream.bucket), + 
"stream.empflag", atomic.LoadInt32(&stream.empflag), "stream.fulflag", atomic.LoadInt32(&stream.fulflag)) + dumpGoroutine(t) + t.Fatal(err) + return + } + + /*t.Log("[normal]w", stream.id, i, "rtt", time.Since(start), + "stream.bucket", atomic.LoadInt32(&stream.bucket), + "stream.guessNeeded", atomic.LoadInt32(&stream.guessNeeded))*/ + + stream.SetReadDeadline(time.Now().Add(TestTimeout)) + if n, err := stream.Read(buf); err != nil { + t.Log(stream.id, i, err, + "session.bucket", atomic.LoadInt32(&session.bucket), // 0 means MaxReceiveBuffer not enough + "stream.bucket", atomic.LoadInt32(&stream.bucket), // >= MaxStreamBuffer means MaxStreamBuffer not enough or flag not send (bug) + "stream.empflag", atomic.LoadInt32(&stream.empflag), "stream.fulflag", atomic.LoadInt32(&stream.fulflag)) + dumpGoroutine(t) + t.Fatal(stream.id, i, err, "since start", time.Since(start)) + return + } else if string(buf[:n]) != msg { + t.Fatal(err) + } else { + t.Log("[normal]r", stream.id, i, "rtt", time.Since(start), + "stream.bucket", atomic.LoadInt32(&stream.bucket), + "stream.guessNeeded", atomic.LoadInt32(&stream.guessNeeded)) + } + time.Sleep(TestDt) + } + } else { + t.Fatal(err) + } + }() + wg.Wait() +} + +func TestReadStreamAfterStreamCloseButRemainData(t *testing.T) { + s1, s2, err := getSmuxStreamPair(nil) + if err != nil { + t.Fatal(err) + } + testReadStreamAfterStreamCloseButRemainData(t, s1, s2) +} + +func TestReadStreamAfterStreamCloseButRemainDataPipe(t *testing.T) { + s1, s2, err := getSmuxStreamPairPipe(nil) + if err != nil { + t.Fatal(err) + } + testReadStreamAfterStreamCloseButRemainData(t, s1, s2) +} + +func testReadStreamAfterStreamCloseButRemainData(t *testing.T, s1 *Stream, s2 *Stream) { + defer s2.Close() + + const N = 10 + var sent string + var received string + + // send and immediately close + nsent := 0 + for i := 0; i < N; i++ { + msg := fmt.Sprintf("hello%v", i) + sent += msg + n, err := s1.Write([]byte(msg)) + if err != nil { + t.Fatal("cannot write") + } + nsent += n + } + s1.Close() + + // read out all remain data + buf := make([]byte, 10) + nrecv := 0 + for nrecv < nsent { + n, err := s2.Read(buf) + if err == nil { + received += string(buf[:n]) + nrecv += n + } else { + t.Fatal("cannot read remain data", err) + break + } + } + + if sent != received { + t.Fatal("data mimatch") + } + + if _, err := s2.Read(buf); err == nil { + t.Fatal("no error after close and no remain data") + } +} + +func TestReadZeroLengthBuffer(t *testing.T) { + s1, s2, err := getSmuxStreamPair(nil) + if err != nil { + t.Fatal(err) + } + testReadZeroLengthBuffer(t, s1, s2) +} + +func TestReadZeroLengthBuffer2(t *testing.T) { + s1, s2, err := getSmuxStreamPairPipe(nil) + if err != nil { + t.Fatal(err) + } + testReadZeroLengthBuffer(t, s1, s2) +} + +func TestReadZeroLengthBuffer3(t *testing.T) { + s1, s2, err := getTCPConnectionPair() + if err != nil { + t.Fatal(err) + } + testReadZeroLengthBuffer(t, s1, s2) +} + +func testReadZeroLengthBuffer(t *testing.T, srv net.Conn, cli net.Conn) { + gotRet := make(chan bool, 1) + readyRead := make(chan bool, 1) + go func(){ + buf := make([]byte, 0) + close(readyRead) + cli.Read(buf) + close(gotRet) + }() + + <-readyRead + time.Sleep(100 * time.Millisecond) + + select { + case <-gotRet: + default: + t.Fatal("reading zero length buffer should not block") + } + srv.Close() + cli.Close() +} + +func TestWriteStreamRace(t *testing.T) { + config := DefaultConfig() + config.MaxFrameSize = 1500 + config.EnableStreamBuffer = true + config.MaxReceiveBuffer = 16 * 1024 * 1024 + 
config.MaxStreamBuffer = config.MaxFrameSize * 8 + + s1, s2, err := getSmuxStreamPair(config) + if err != nil { + t.Fatal(err) + } + testWriteStreamRace(t, s1, s2, config.MaxFrameSize) +} + +func TestWriteStreamRace2(t *testing.T) { + config := DefaultConfig() + config.MaxFrameSize = 1500 + config.EnableStreamBuffer = true + config.MaxReceiveBuffer = 16 * 1024 * 1024 + config.MaxStreamBuffer = config.MaxFrameSize * 8 + + s1, s2, err := getSmuxStreamPairPipe(config) + if err != nil { + t.Fatal(err) + } + testWriteStreamRace(t, s1, s2, config.MaxFrameSize) +} + +func TestWriteStreamRaceTCP(t *testing.T) { + s1, s2, err := getTCPConnectionPair() + if err != nil { + t.Fatal(err) + } + testWriteStreamRace(t, s1, s2, 1500) // tcp frame size == 1500 ? +} + +func TestWriteStreamRacePipe(t *testing.T) { // go v1.9.x won't pass + s1, s2 := net.Pipe() + testWriteStreamRace(t, s1, s2, 1500) // tcp frame size == 1500 ? +} + +func testWriteStreamRace(t *testing.T, s1 net.Conn, s2 net.Conn, frameSize int) { + defer s1.Close() + defer s2.Close() + + mkMsg := func(char byte, size int) []byte { + buf := make([]byte, size, size) + for i := range buf { + buf[i] = char + } + return buf + } + + MAXSIZE := frameSize * 4 + testMsg := map[byte][]byte{ + 'a': mkMsg('a', MAXSIZE), + 'b': mkMsg('b', frameSize * 3), + 'c': mkMsg('c', frameSize * 2), + 'd': mkMsg('d', frameSize), + 'e': mkMsg('e', frameSize / 2), + } + + // Parallel Write(), data should not reorder in one Write() call + die := make(chan struct{}) + var wg sync.WaitGroup + for _, msg := range testMsg { + wg.Add(1) + go func(s net.Conn, msg []byte) { + defer wg.Done() + for { // keep write untill all stream end + select { + case <-die: + return + default: + } + s.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + _, err := s.Write(msg) + if err != nil { + //t.Fatal("write data error", err) + return + } + } + }(s1, msg) + } + + // read and check data + const N = 100 * 1000 + buf := make([]byte, MAXSIZE, MAXSIZE) + for i := 0; i < N; i++ { + _, err := io.ReadFull(s2, buf[:1]) + if err != nil { + t.Fatal("cannot read data", err) + break + } + msg := testMsg[buf[0]] + n, err := io.ReadFull(s2, buf[1:len(msg)]) + if err == nil { + if bytes.Compare(buf[:n+1], msg) != 0 { + t.Fatal("data mimatch", n, string(buf[0])) + break + } + } else { + t.Fatal("cannot read data", err) + break + } + } + close(die) + wg.Wait() +} + +func TestSmallBufferReadWrite(t *testing.T) { + c1, c2, err := getTCPConnectionPair() + if err != nil { + t.Fatal(err) + return + } + + testSmallBufferReadWrite(t, c1, c2) +} + +func testSmallBufferReadWrite(t *testing.T, srv net.Conn, cli net.Conn) { + defer srv.Close() + defer cli.Close() + + config := &Config{ + KeepAliveInterval: 10000 * time.Millisecond, + KeepAliveTimeout: 50000 * time.Millisecond, + MaxFrameSize: 1 * 1024, + MaxReceiveBuffer: 2 * 1024, + EnableStreamBuffer: false, + MaxStreamBuffer: 4 * 1024, + BoostTimeout: 0 * time.Millisecond, + } + + + go func (conn net.Conn) { // echo server + session, _ := Server(conn, config) + for { + if stream, err := session.AcceptStream(); err == nil { + go func(s io.ReadWriteCloser) { + defer s.Close() + buf := make([]byte, 1024 * 1024, 1024 * 1024) + for { // just echo + n, err := s.Read(buf) + if err != nil { + return + } + s.Write(buf[:n]) + } + }(stream) + } else { + return + } + } + }(srv) + + var dumpSess = func (t *testing.T, sess *Session) { + sess.streamLock.Lock() + defer sess.streamLock.Unlock() + + t.Logf("================\n") + t.Log("session.bucket", 
atomic.LoadInt32(&sess.bucket), "session.streams.len", len(sess.streams)) + for _, stream := range sess.streams { + t.Logf("id: %v, addr: %p, bucket: %v, empflag: %v, fulflag: %v\n", + stream.id, stream, atomic.LoadInt32(&stream.bucket), atomic.LoadInt32(&stream.empflag), atomic.LoadInt32(&stream.fulflag)) + } + t.Logf("================\n") + } + + flag := int32(1) + var wg sync.WaitGroup + + var test = func(session *Session) { // fast write & slow read + defer wg.Done() + + stream, err := session.OpenStream() + if err == nil { + t.Log("[stream][start]", stream.id) + defer func() { + stream.Close() + t.Log("[stream][end]", stream.id) + }() + + const SIZE = 8 * 1024 // Bytes + const SPDW = 16 * 1024 * 1024 // Bytes/s + const SPDR = 16 * 1024 * 1024 // Bytes/s + const TestDtW = time.Second / time.Duration(SPDW/SIZE) + const TestDtR = time.Second / time.Duration(SPDR/SIZE) + + var fwg sync.WaitGroup + fwg.Add(1) + go func() { // read + defer fwg.Done() + rbuf := make([]byte, SIZE, SIZE) + for atomic.LoadInt32(&flag) > 0 { + stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if _, err := stream.Read(rbuf); err != nil { // should not timeout when write speed == read speed + dumpGoroutine(t) + dumpSess(t, session) + t.Fatal("read data error", err) + break + } + time.Sleep(TestDtR) // slow down read + } + }() + + buf := make([]byte, SIZE, SIZE) + for i := range buf { + buf[i] = byte('-') + } + + for atomic.LoadInt32(&flag) > 0 { // write + stream.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + _, err := stream.Write(buf) + if err != nil { + if strings.Contains(err.Error(), "i/o timeout") { + continue + } + dumpSess(t, session) + t.Fatal("write data error", err) + break + } + time.Sleep(TestDtW) // slow down write + } + fwg.Wait() + } else { + t.Fatal(err) + } + } + + session, _ := Client(cli, config) + for i := 0; i < 4; i++ { + wg.Add(1) + go test(session) + } + + time.Sleep(5 * time.Second) + atomic.StoreInt32(&flag, int32(0)) + wg.Wait() +} + func BenchmarkAcceptClose(b *testing.B) { _, stop, cli, err := setupServer(b) if err != nil { b.Fatal(err) } defer stop() + benchmarkAcceptClose(b, cli) +} + +func BenchmarkAcceptClosePipe(b *testing.B) { + _, stop, cli, err := setupServerPipe(b) + if err != nil { + b.Fatal(err) + } + defer stop() + benchmarkAcceptClose(b, cli) +} + +func benchmarkAcceptClose(b *testing.B, cli net.Conn) { session, _ := Client(cli, nil) for i := 0; i < b.N; i++ { if stream, err := session.OpenStream(); err == nil { @@ -734,8 +1511,59 @@ func BenchmarkAcceptClose(b *testing.B) { } } } + func BenchmarkConnSmux(b *testing.B) { - cs, ss, err := getSmuxStreamPair() + config := DefaultConfig() + config.KeepAliveInterval = 5000 * time.Millisecond + config.KeepAliveTimeout = 20000 * time.Millisecond + config.EnableStreamBuffer = false + + cs, ss, err := getSmuxStreamPair(config) + if err != nil { + b.Fatal(err) + } + defer cs.Close() + defer ss.Close() + bench(b, cs, ss) +} + +func BenchmarkConnSmuxEnableStreamToken(b *testing.B) { + config := DefaultConfig() + config.KeepAliveInterval = 50 * time.Millisecond + config.KeepAliveTimeout = 200 * time.Millisecond + config.EnableStreamBuffer = true + + cs, ss, err := getSmuxStreamPair(config) + if err != nil { + b.Fatal(err) + } + defer cs.Close() + defer ss.Close() + bench(b, cs, ss) +} + +func BenchmarkConnSmuxPipe(b *testing.B) { + config := DefaultConfig() + config.KeepAliveInterval = 5000 * time.Millisecond + config.KeepAliveTimeout = 20000 * time.Millisecond + config.EnableStreamBuffer = false + + cs, ss, 
err := getSmuxStreamPairPipe(config) + if err != nil { + b.Fatal(err) + } + defer cs.Close() + defer ss.Close() + bench(b, cs, ss) +} + +func BenchmarkConnSmuxPipeEnableStreamToken(b *testing.B) { + config := DefaultConfig() + config.KeepAliveInterval = 50 * time.Millisecond + config.KeepAliveTimeout = 200 * time.Millisecond + config.EnableStreamBuffer = true + + cs, ss, err := getSmuxStreamPairPipe(config) if err != nil { b.Fatal(err) } @@ -754,17 +1582,32 @@ func BenchmarkConnTCP(b *testing.B) { bench(b, cs, ss) } -func getSmuxStreamPair() (*Stream, *Stream, error) { +func BenchmarkConnPipe(b *testing.B) { + cs, ss := net.Pipe() + defer cs.Close() + defer ss.Close() + bench(b, cs, ss) +} + +func getSmuxStreamPair(config *Config) (*Stream, *Stream, error) { c1, c2, err := getTCPConnectionPair() if err != nil { return nil, nil, err } + return getSmuxStreamPairInternal(c1, c2, config) +} + +func getSmuxStreamPairPipe(config *Config) (*Stream, *Stream, error) { + c1, c2 := net.Pipe() + return getSmuxStreamPairInternal(c1, c2, config) +} - s, err := Server(c2, nil) +func getSmuxStreamPairInternal(c1, c2 net.Conn, config *Config) (*Stream, *Stream, error) { + s, err := Server(c2, config) if err != nil { return nil, nil, err } - c, err := Client(c1, nil) + c, err := Client(c1, config) if err != nil { return nil, nil, err } @@ -827,7 +1670,10 @@ func bench(b *testing.B, rd io.Reader, wr io.Writer) { defer wg.Done() count := 0 for { - n, _ := rd.Read(buf2) + n, err := rd.Read(buf2) + if err != nil { + b.Fatal("Read()", err) + } count += n if count == 128*1024*b.N { return @@ -839,3 +1685,10 @@ func bench(b *testing.B, rd io.Reader, wr io.Writer) { } wg.Wait() } + +func dumpGoroutine(t *testing.T) { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 2) + t.Log(b.String()) +} + diff --git a/stream.go b/stream.go index 2a2b82f..665447b 100644 --- a/stream.go +++ b/stream.go @@ -7,8 +7,6 @@ import ( "sync" "sync/atomic" "time" - - "github.com/pkg/errors" ) // Stream implements net.Conn @@ -24,6 +22,19 @@ type Stream struct { dieLock sync.Mutex readDeadline atomic.Value writeDeadline atomic.Value + writeLock sync.Mutex + + bucket int32 // token bucket + bucketNotify chan struct{} // used for waiting for tokens + fulflag int32 + empflagLock sync.Mutex + empflag int32 + countRead int32 // for guess read speed + boostTimeout time.Time // for initial boost + guessBucket int32 // for guess needed stream buffer size + + lastWrite atomic.Value + guessNeeded int32 } // newStream initiates a Stream struct @@ -34,6 +45,14 @@ func newStream(id uint32, frameSize int, sess *Session) *Stream { s.frameSize = frameSize s.sess = sess s.die = make(chan struct{}) + + s.bucket = int32(0) + s.bucketNotify = make(chan struct{}, 1) + s.empflag = int32(1) + s.countRead = int32(0) + s.boostTimeout = time.Now().Add(s.sess.config.BoostTimeout) + s.guessBucket = int32(s.sess.config.MaxStreamBuffer) + s.lastWrite.Store(time.Now()) return s } @@ -47,7 +66,7 @@ func (s *Stream) Read(b []byte) (n int, err error) { if len(b) == 0 { select { case <-s.die: - return 0, errors.New(errBrokenPipe) + return 0, ErrBrokenPipe default: return 0, nil } @@ -67,19 +86,24 @@ READ: if n > 0 { s.sess.returnTokens(n) + s.returnTokens(n) return n, nil } else if atomic.LoadInt32(&s.rstflag) == 1 { _ = s.Close() return 0, io.EOF } + if s.sess.config.EnableStreamBuffer { + s.sendResume() + } + select { case <-s.chReadEvent: goto READ case <-deadline: return n, errTimeout case <-s.die: - return 0, errors.New(errBrokenPipe) + return 0, 
ErrBrokenPipe } } @@ -94,12 +118,24 @@ func (s *Stream) Write(b []byte) (n int, err error) { select { case <-s.die: - return 0, errors.New(errBrokenPipe) + return 0, ErrBrokenPipe default: } + if atomic.LoadInt32(&s.fulflag) == 1 { + select { + case <-s.bucketNotify: + case <-s.die: + return 0, ErrBrokenPipe + case <-deadline: + return 0, errTimeout + } + } + frames := s.split(b, cmdPSH, s.id) sent := 0 + s.writeLock.Lock() + defer s.writeLock.Unlock() for k := range frames { req := writeRequest{ frame: frames[k], @@ -109,7 +145,7 @@ func (s *Stream) Write(b []byte) (n int, err error) { select { case s.sess.writes <- req: case <-s.die: - return sent, errors.New(errBrokenPipe) + return sent, ErrBrokenPipe case <-deadline: return sent, errTimeout } @@ -121,7 +157,7 @@ func (s *Stream) Write(b []byte) (n int, err error) { return sent, result.err } case <-s.die: - return sent, errors.New(errBrokenPipe) + return sent, ErrBrokenPipe case <-deadline: return sent, errTimeout } @@ -136,12 +172,17 @@ func (s *Stream) Close() error { select { case <-s.die: s.dieLock.Unlock() - return errors.New(errBrokenPipe) + if atomic.LoadInt32(&s.rstflag) == 1 { + return nil + } + return ErrBrokenPipe default: close(s.die) s.dieLock.Unlock() - s.sess.streamClosed(s.id) - _, err := s.sess.writeFrame(newFrame(cmdFIN, s.id)) + s.sess.streamClosed(s) + s.writeLock.Lock() + _, err := s.sess.writeFrameInternal(newFrame(cmdFIN, s.id), nil) + s.writeLock.Unlock() return err } } @@ -218,6 +259,35 @@ func (s *Stream) pushBytes(p []byte) { s.bufferLock.Lock() s.buffer.Write(p) s.bufferLock.Unlock() + + if !s.sess.config.EnableStreamBuffer { + return + } + + n := len(p) + used := atomic.AddInt32(&s.bucket, int32(n)) + lastReadOut := atomic.SwapInt32(&s.countRead, int32(0)) // reset read + + // hard limit + if used > int32(s.sess.config.MaxStreamBuffer) { + s.sendPause() + return + } + + if lastReadOut != 0 { + s.lastWrite.Store(time.Now()) + needed := atomic.LoadInt32(&s.guessNeeded) + s.guessBucket = int32((int64(s.guessBucket) * 8 + int64(needed) * 2) / 10) + } + + if used <= s.guessBucket { + s.boostTimeout = time.Now().Add(s.sess.config.BoostTimeout) + return + } + + if time.Now().After(s.boostTimeout) { + s.sendPause() + } } // recycleTokens transform remaining bytes to tokens(will truncate buffer) @@ -259,6 +329,56 @@ func (s *Stream) markRST() { atomic.StoreInt32(&s.rstflag, 1) } +// mark this stream has been pause write +func (s *Stream) pauseWrite() { + atomic.StoreInt32(&s.fulflag, 1) +} + +// mark this stream has been resume write +func (s *Stream) resumeWrite() { + atomic.StoreInt32(&s.fulflag, 0) + select { + case s.bucketNotify <- struct{}{}: + default: + } +} + +// returnTokens is called by stream to return token after read +func (s *Stream) returnTokens(n int) { + if !s.sess.config.EnableStreamBuffer { + return + } + + used := atomic.AddInt32(&s.bucket, -int32(n)) + totalRead := atomic.AddInt32(&s.countRead, int32(n)) + lastWrite, _ := s.lastWrite.Load().(time.Time) + dt := time.Now().Sub(lastWrite) + 1 + rtt := s.sess.rtt.Load().(time.Duration) + needed := totalRead * int32(rtt / dt) + atomic.StoreInt32(&s.guessNeeded, needed) + if used <= 0 || needed >= used { + s.sendResume() + } +} + +// send cmdFUL to pause write +func (s *Stream) sendPause() { + s.empflagLock.Lock() + if atomic.SwapInt32(&s.empflag, 0) == 1 { + s.sess.writeFrameCtrl(newFrame(cmdFUL, s.id), time.After(s.sess.config.KeepAliveTimeout)) + } + s.empflagLock.Unlock() +} + +// send cmdEMP to resume write +func (s *Stream) sendResume() { + 
s.empflagLock.Lock() + if atomic.SwapInt32(&s.empflag, 1) == 0 { + s.sess.writeFrameHalf(newFrame(cmdEMP, s.id), time.After(s.sess.config.KeepAliveTimeout)) + } + s.empflagLock.Unlock() +} + var errTimeout error = &timeoutError{} type timeoutError struct{}
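
Usage sketch (not part of the patch): the example below shows how the new per-stream buffer options and the RTT getter introduced above might be wired up from application code. It assumes the API added in this diff (EnableStreamBuffer, MaxStreamBuffer, BoostTimeout, Session.GetRTT) and uses a placeholder dial address; treat it as an illustration, not as the library's documented interface.

package main

import (
	"log"
	"net"
	"time"

	"github.com/xtaci/smux"
)

func main() {
	// Placeholder address: replace with a real smux server endpoint.
	conn, err := net.Dial("tcp", "127.0.0.1:9000")
	if err != nil {
		log.Fatal(err)
	}

	cfg := smux.DefaultConfig()
	cfg.EnableStreamBuffer = true        // enable per-stream flow control (cmdFUL/cmdEMP)
	cfg.MaxStreamBuffer = 200 * 8 * 1024 // hard per-stream buffer cap, in bytes
	cfg.BoostTimeout = 10 * time.Second  // initial period before pause frames are sent
	if err := smux.VerifyConfig(cfg); err != nil {
		log.Fatal(err)
	}

	session, err := smux.Client(conn, cfg)
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	stream, err := session.OpenStream()
	if err != nil {
		log.Fatal(err)
	}
	defer stream.Close()

	if _, err := stream.Write([]byte("hello")); err != nil {
		log.Fatal(err)
	}

	// GetRTT reports the round-trip time measured by the keepalive
	// cmdNOP/cmdACK exchange added in this patch.
	log.Println("measured RTT:", session.GetRTT())
}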