Skip to content

Commit

Permalink
+WriteMessageDelay接口(实验)
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong committed Aug 6, 2023
1 parent 086e69a commit c3995c1
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 7 deletions.
27 changes: 26 additions & 1 deletion common_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
// limitations under the License.
package quickws

import "unicode/utf8"
import (
"time"
"unicode/utf8"
)

// 1. callback
// 配置客户端callback
Expand Down Expand Up @@ -200,3 +203,25 @@ func WithClientBufioMultipleTimesPayloadSize(mt float32) ClientOption {
o.bufioMultipleTimesPayloadSize = mt
}
}

// 13. 配置延迟发送
// 配置延迟最大发送时间
func WithServerMaxDelayWriteDuration(d time.Duration) ServerOption {
return func(o *ConnOption) {
o.maxDelayWriteDuration = d
}
}

// 14. 配置最大延迟个数
func WithServerMaxDelayWriteNum(n int32) ServerOption {
return func(o *ConnOption) {
o.maxDelayWriteNum = n
}
}

// 15. 配置延迟包的初始化buffer大小
func WithServerDelayWriteInitBufferSize(n int32) ServerOption {
return func(o *ConnOption) {
o.delayWriteInitBufferSize = n
}
}
12 changes: 9 additions & 3 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ type Config struct {
disableBufioClearHack bool // 关闭bufio的clear hack优化
utf8Check func([]byte) bool // utf8检查
readTimeout time.Duration
windowsMultipleTimesPayloadSize float32 // 设置几倍的payload大小
bufioMultipleTimesPayloadSize float32 // 设置几倍的payload大小
parseMode parseMode // 解析模式
windowsMultipleTimesPayloadSize float32 // 设置几倍的payload大小
bufioMultipleTimesPayloadSize float32 // 设置几倍的payload大小
parseMode parseMode // 解析模式
maxDelayWriteNum int32 // 最大延迟包的个数, 默认值为10
delayWriteInitBufferSize int32 // 延迟写入的初始缓冲区大小, 默认值是8k
maxDelayWriteDuration time.Duration // 最大延迟时间, 默认值是10ms
}

func (c *Config) initPayloadSize() int {
Expand All @@ -41,7 +44,10 @@ func (c *Config) initPayloadSize() int {

// 默认设置
func (c *Config) defaultSetting() {
c.maxDelayWriteNum = 10
c.windowsMultipleTimesPayloadSize = 1.0
c.delayWriteInitBufferSize = 8 * 1024
c.maxDelayWriteDuration = 10 * time.Millisecond
c.tcpNoDelay = true
c.parseMode = ParseModeWindows
// 对于text消息,默认不检查text 是否为utf8字符
Expand Down
120 changes: 117 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
"unsafe"

Expand All @@ -42,16 +43,27 @@ const (

// var _ net.Conn = (*Conn)(nil)

type delayWrite struct {
delayNum int32 // 实验某些特性加的字段
delayMu sync.Mutex // 实验某些特性加的字段
delayBuf *bytes.Buffer // 实验某些特性加的字段
delayTimeout *time.Timer // 实验某些特性加的字段
delayErr error
}

type Conn struct {
read *bufio.Reader // read 和fr同时只能使用一个
closed int32
read *bufio.Reader // read 和fr同时只能使用一个
*Config
c net.Conn
client bool
once sync.Once

fr fixedreader.FixedReader
fw fixedwriter.FixedWriter
bp bytespool.BytesPool
bp bytespool.BytesPool // 实验某些特性加的字段

delayWrite
}

func setNoDelay(c net.Conn, noDelay bool) error {
Expand All @@ -68,14 +80,16 @@ func setNoDelay(c net.Conn, noDelay bool) error {
func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, read *bufio.Reader, bp bytespool.BytesPool) *Conn {
_ = setNoDelay(c, conf.tcpNoDelay)

return &Conn{
con := &Conn{
c: c,
client: client,
Config: conf,
fr: fr,
read: read,
bp: bp,
}

return con
}

func (c *Conn) writeErrAndOnClose(code StatusCode, userErr error) error {
Expand Down Expand Up @@ -341,6 +355,10 @@ func (w *wrapBuffer) Close() error {
}

func (c *Conn) WriteMessage(op Opcode, writeBuf []byte) (err error) {
if atomic.LoadInt32(&c.closed) == 1 {
return ErrClosed
}

if op == opcode.Text {
if !c.utf8Check(writeBuf) {
return ErrTextNotUTF8
Expand Down Expand Up @@ -489,6 +507,102 @@ func (c *Conn) Close() (err error) {
c.once.Do(func() {
c.bp.Free()
err = c.c.Close()
if c.delayTimeout != nil {
c.delayTimeout.Stop()
c.delayMu.Lock()
c.delayBuf = nil
c.delayMu.Unlock()
}
atomic.StoreInt32(&c.closed, 1)
})
return
}

func (c *Conn) writerDelayBufSafe() {
c.delayMu.Lock()
c.delayErr = c.writerDelayBufInner()
c.delayMu.Unlock()
return
}

func (c *Conn) writerDelayBufInner() (err error) {
if c.delayBuf == nil {
return nil
}
_, err = c.c.Write(c.delayBuf.Bytes())
if c.delayTimeout != nil {
c.delayTimeout.Reset(c.maxDelayWriteDuration)
}
c.delayNum = 0
c.delayBuf.Reset()
return
}

// 延迟写消息
// 1. 如果缓存的消息超过了多少条数
// 2. 如果缓存的消费超过了多久的时间
func (c *Conn) WriteMessageDelay(op Opcode, writeBuf []byte) (err error) {
if atomic.LoadInt32(&c.closed) == 1 {
return ErrClosed
}

if op == opcode.Text {
if !c.utf8Check(writeBuf) {
return ErrTextNotUTF8
}
}

rsv1 := c.compression && (op == opcode.Text || op == opcode.Binary)
if rsv1 {
var out wrapBuffer
w := compressNoContextTakeover(&out, defaultCompressionLevel)
if _, err = io.Copy(w, bytes.NewReader(writeBuf)); err != nil {
return
}

if err = w.Close(); err != nil {
return
}
writeBuf = out.Bytes()
}

// 初始化缓存
if c.delayBuf == nil && c.delayWriteInitBufferSize > 0 {

c.delayMu.Lock()
// TODO sync.Pool管理下, 如果size是1k 2k 3k
delayBuf := make([]byte, 0, c.delayWriteInitBufferSize)
c.delayBuf = bytes.NewBuffer(delayBuf)
c.delayMu.Unlock()
}
// 初始化定时器
if c.delayTimeout == nil && c.maxDelayWriteDuration > 0 {
c.delayTimeout = time.AfterFunc(c.maxDelayWriteDuration, c.writerDelayBufSafe)
}

// 缓存的消息超过最大值, 则直接写入
c.delayMu.Lock()
if c.delayNum+1 == c.maxDelayWriteNum {
err = c.writerDelayBufInner()
c.delayMu.Unlock()
return err
}
c.delayMu.Unlock()

maskValue := uint32(0)
if c.client {
maskValue = rand.Uint32()
}

// go func() {
// 为了平衡生产者,消费者的速度,这里不再使用协程

c.delayMu.Lock()
if c.delayBuf != nil {
frame.WriteFrameToBytes(c.delayBuf, writeBuf, true, rsv1, c.client, op, maskValue)

Check failure on line 602 in conn.go

View workflow job for this annotation

GitHub Actions / Go 1.20 sample

undefined: frame.WriteFrameToBytes
}
c.delayNum++
c.delayMu.Unlock()
// }()
return nil
}
3 changes: 3 additions & 0 deletions err.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ package quickws
import "errors"

var (
// conn已经被关闭
ErrClosed = errors.New("closed")

ErrWrongStatusCode = errors.New("Wrong status code")
ErrUpgradeFieldValue = errors.New("The value of the upgrade field is not 'websocket'")
ErrConnectionFieldValue = errors.New("The value of the connection field is not 'upgrade'")
Expand Down

0 comments on commit c3995c1

Please sign in to comment.