Skip to content

Commit

Permalink
更新
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong committed Jul 20, 2023
1 parent fb34d0a commit a1d56ef
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 88 deletions.
31 changes: 29 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"time"

"github.com/antlabs/wsutil/bufio2"
"github.com/antlabs/wsutil/bytespool"
"github.com/antlabs/wsutil/enum"
"github.com/antlabs/wsutil/fixedreader"
Expand All @@ -39,7 +40,31 @@ type DialOption struct {
u *url.URL
tlsConfig *tls.Config
dialTimeout time.Duration
config
Config
}

func ClientOpt(opts ...ClientOption) *Config {
var dial DialOption
dial.defaultSetting()
for _, o := range opts {
o(&dial)
}
return &dial.Config
}

func DialWithConfig(rawUrl string, conf *Config) (*Conn, error) {
var dial DialOption
u, err := url.Parse(rawUrl)
if err != nil {
return nil, err
}

dial.u = u
dial.dialTimeout = defaultTimeout
if dial.Header == nil {
dial.Header = make(http.Header)
}
return dial.Dial()
}

// https://datatracker.ietf.org/doc/html/rfc6455#section-4.1
Expand Down Expand Up @@ -231,7 +256,9 @@ func (d *DialOption) Dial() (c *Conn, err error) {
copy(*buf, b)
fr.W = len(b)
}
bufio2.ClearReader(br)
br = nil
}
// fmt.Println(brw.Reader.Buffered())
return newConn(conn, true, d.config, fr, br, bp), nil
return newConn(conn, true, &d.Config, fr, br, bp), nil
}
6 changes: 3 additions & 3 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"github.com/antlabs/wsutil/enum"
)

type config struct {
type Config struct {
Callback
tcpNoDelay bool
replyPing bool // 开启自动回复
Expand All @@ -34,12 +34,12 @@ type config struct {
parseMode parseMode // 解析模式
}

func (c *config) initPayloadSize() int {
func (c *Config) initPayloadSize() int {
return int(1024.0 + float32(enum.MaxFrameHeaderSize)*c.windowsMultipleTimesPayloadSize)
}

// 默认设置
func (c *config) defaultSetting() {
func (c *Config) defaultSetting() {
c.windowsMultipleTimesPayloadSize = 1.0
c.tcpNoDelay = true
c.parseMode = ParseModeWindows
Expand Down
10 changes: 5 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ const (
// var _ net.Conn = (*Conn)(nil)

type Conn struct {
read *bufio.Reader // read 和fr同时只能使用一个
read *bufio.Reader // read 和fr同时只能使用一个
*Config
c net.Conn
client bool
config
once sync.Once
once sync.Once

fr fixedreader.FixedReader
fw fixedwriter.FixedWriter
Expand All @@ -63,13 +63,13 @@ func setNoDelay(c net.Conn, noDelay bool) error {
return nil
}

func newConn(c net.Conn, client bool, conf config, fr fixedreader.FixedReader, read *bufio.Reader, bp bytespool.BytesPool) *Conn {
func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, read *bufio.Reader, bp bytespool.BytesPool) *Conn {
_ = setNoDelay(c, conf.tcpNoDelay)

return &Conn{
c: c,
client: client,
config: conf,
Config: conf,
fr: fr,
read: read,
bp: bp,
Expand Down
79 changes: 2 additions & 77 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,11 @@
package quickws

import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"

"github.com/antlabs/wsutil/bufio2"
"github.com/antlabs/wsutil/bytespool"
"github.com/antlabs/wsutil/enum"
"github.com/antlabs/wsutil/fixedreader"
"github.com/antlabs/wsutil/rsp"
)

var (
Expand All @@ -41,73 +32,7 @@ var (
)

type ConnOption struct {
config
}

func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *Conn, err error) {
var conf ConnOption
conf.defaultSetting()
for _, o := range opts {
o(&conf)
}

if ecode, err := checkRequest(r); err != nil {
http.Error(w, err.Error(), ecode)
return nil, err
}

hi, ok := w.(http.Hijacker)
if !ok {
return nil, ErrNotFoundHijacker
}

var read *bufio.Reader
var conn net.Conn
var rw *bufio.ReadWriter
if conf.parseMode == ParseModeWindows {
// 这里不需要rw,直接使用conn
conn, rw, err = hi.Hijack()
bufio2.ClearReadWriter(rw)
rsp.ClearRsp(w)
rw = nil
} else {
var rw *bufio.ReadWriter
conn, rw, err = hi.Hijack()
read = rw.Reader
rw = nil
}
if err != nil {
return nil, err
}

// 是否打开解压缩
// 外层接收压缩, 并且客户端发送扩展过来
if conf.decompression {
conf.decompression = needDecompression(r.Header)
}

buf := bytespool.GetUpgradeRespBytes()

tmpWriter := bytes.NewBuffer((*buf)[:0])
defer func() {
bytespool.PutUpgradeRespBytes(buf)
tmpWriter = nil
}()
if err = prepareWriteResponse(r, tmpWriter, conf.config); err != nil {
return
}

if _, err := conn.Write(tmpWriter.Bytes()); err != nil {
return nil, err
}

var fr fixedreader.FixedReader
var bp bytespool.BytesPool
bp.Init()
if conf.parseMode == ParseModeWindows {
fr.Init(conn, bytespool.GetBytes(conf.initPayloadSize()+enum.MaxFrameHeaderSize))
}
return newConn(conn, false, conf.config, fr, read, bp), nil
Config
}

func writeHeaderKey(w io.Writer, key []byte) (err error) {
Expand All @@ -133,7 +58,7 @@ func writeHeaderVal(w io.Writer, val []byte) (err error) {

// https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.2
// 第5小点
func prepareWriteResponse(r *http.Request, w io.Writer, cnf config) (err error) {
func prepareWriteResponse(r *http.Request, w io.Writer, cnf *Config) (err error) {
if _, err = w.Write(bytesHeaderUpgrade); err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,6 @@ func Test_Upgrade(t *testing.T) {
}

var out bytes.Buffer
prepareWriteResponse(r, &out, config{})
prepareWriteResponse(r, &out, &Config{})
fmt.Printf("%s\n %d", out.Bytes(), out.Len())
}
114 changes: 114 additions & 0 deletions upgrade.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2021-2023 antlabs. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package quickws

import (
"bufio"
"bytes"
"net"
"net/http"

"github.com/antlabs/wsutil/bufio2"
"github.com/antlabs/wsutil/bytespool"
"github.com/antlabs/wsutil/enum"
"github.com/antlabs/wsutil/fixedreader"
"github.com/antlabs/wsutil/rsp"
)

type UpgradeServer struct {
config Config
}

func NewUpgrade(opts ...ServerOption) *UpgradeServer {
var conf ConnOption
conf.defaultSetting()
for _, o := range opts {
o(&conf)
}
return &UpgradeServer{config: conf.Config}
}

func (u *UpgradeServer) Upgrade(w http.ResponseWriter, r *http.Request) (c *Conn, err error) {
return upgradeInner(w, r, &u.config)
}

func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *Conn, err error) {
var conf ConnOption
conf.defaultSetting()
for _, o := range opts {
o(&conf)
}
return upgradeInner(w, r, &conf.Config)
}

func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn, err error) {
if ecode, err := checkRequest(r); err != nil {
http.Error(w, err.Error(), ecode)
return nil, err
}

hi, ok := w.(http.Hijacker)
if !ok {
return nil, ErrNotFoundHijacker
}

var read *bufio.Reader
var conn net.Conn
var rw *bufio.ReadWriter
if conf.parseMode == ParseModeWindows {
// 这里不需要rw,直接使用conn
conn, rw, err = hi.Hijack()
bufio2.ClearReadWriter(rw)
rsp.ClearRsp(w)
rw = nil
} else {
var rw *bufio.ReadWriter
conn, rw, err = hi.Hijack()
read = rw.Reader
rw = nil
}
if err != nil {
return nil, err
}

// 是否打开解压缩
// 外层接收压缩, 并且客户端发送扩展过来
if conf.decompression {
conf.decompression = needDecompression(r.Header)
}

buf := bytespool.GetUpgradeRespBytes()

tmpWriter := bytes.NewBuffer((*buf)[:0])
defer func() {
bytespool.PutUpgradeRespBytes(buf)
tmpWriter = nil
}()
if err = prepareWriteResponse(r, tmpWriter, conf); err != nil {
return
}

if _, err := conn.Write(tmpWriter.Bytes()); err != nil {
return nil, err
}

var fr fixedreader.FixedReader
var bp bytespool.BytesPool
bp.Init()
if conf.parseMode == ParseModeWindows {
fr.Init(conn, bytespool.GetBytes(conf.initPayloadSize()+enum.MaxFrameHeaderSize))
}
return newConn(conn, false, conf, fr, read, bp), nil
}

0 comments on commit a1d56ef

Please sign in to comment.