Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sec web socket protocol #11

Merged
merged 3 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,35 @@ var (
)

type DialOption struct {
Header http.Header
u *url.URL
tlsConfig *tls.Config
dialTimeout time.Duration
Header http.Header
u *url.URL
tlsConfig *tls.Config
dialTimeout time.Duration
bindClientHttpHeader *http.Header // 握手成功之后, 客户端获取http.Header,
Config
}

func ClientOptionToConf(opts ...ClientOption) *Config {
func ClientOptionToConf(opts ...ClientOption) *DialOption {
var dial DialOption
dial.defaultSetting()
for _, o := range opts {
o(&dial)
}
return &dial.Config
return &dial
}

func DialConf(rawUrl string, conf *Config) (*Conn, error) {
var dial DialOption
func DialConf(rawUrl string, conf *DialOption) (*Conn, error) {
u, err := url.Parse(rawUrl)
if err != nil {
return nil, err
}

dial.u = u
dial.dialTimeout = defaultTimeout
if dial.Header == nil {
dial.Header = make(http.Header)
conf.u = u
conf.dialTimeout = defaultTimeout
if conf.Header == nil {
conf.Header = make(http.Header)
}
return dial.Dial()
return conf.Dial()
}

// https://datatracker.ietf.org/doc/html/rfc6455#section-4.1
Expand Down Expand Up @@ -222,6 +222,10 @@ func (d *DialOption) Dial() (c *Conn, err error) {
return nil, err
}

if d.bindClientHttpHeader != nil {
*d.bindClientHttpHeader = rsp.Header.Clone()
}

cd := maybeCompressionDecompression(rsp.Header)
if d.decompression {
d.decompression = cd
Expand Down
50 changes: 50 additions & 0 deletions client_option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,54 @@ func Test_ClientOption(t *testing.T) {
t.Error("not run server:method fail")
}
})

t.Run("6.1 Dial: WithClientBindHTTPHeader", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := Upgrade(w, r)
if err != nil {
t.Error(err)
}
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
h := make(http.Header)
con, err := Dial(url, WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{
"Sec-WebSocket-Protocol": []string{"token"},
}))
if err != nil {
t.Error(err)
}
defer con.Close()

if h["Sec-Websocket-Protocol"][0] != "token" {
t.Error("header fail")
}
})

t.Run("6.2 Dial: WithClientBindHTTPHeader", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := Upgrade(w, r)
if err != nil {
t.Error(err)
}
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
h := make(http.Header)
con, err := DialConf(url, ClientOptionToConf(WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{
"Sec-WebSocket-Protocol": []string{"token"},
})))
if err != nil {
t.Error(err)
}
defer con.Close()

if h["Sec-Websocket-Protocol"][0] != "token" {
t.Error("header fail")
}
})
}
17 changes: 12 additions & 5 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,45 @@ import (

type ClientOption func(*DialOption)

// 配置tls.config
// 1.配置tls.config
func WithClientTLSConfig(tls *tls.Config) ClientOption {
return func(o *DialOption) {
o.tlsConfig = tls
}
}

// 配置http.Header
// 2.配置http.Header
func WithClientHTTPHeader(h http.Header) ClientOption {
return func(o *DialOption) {
o.Header = h
}
}

// 配置握手时的timeout
// 3.配置握手时的timeout
func WithClientDialTimeout(t time.Duration) ClientOption {
return func(o *DialOption) {
o.dialTimeout = t
}
}

// 配置压缩
// 4.配置压缩
func WithClientCompression() ClientOption {
return func(o *DialOption) {
o.compression = true
}
}

// 配置压缩和解压缩
// 5.配置压缩和解压缩
func WithClientDecompressAndCompress() ClientOption {
return func(o *DialOption) {
o.compression = true
o.decompression = true
}
}

// 6.获取http header
func WithClientBindHTTPHeader(h *http.Header) ClientOption {
return func(o *DialOption) {
o.bindClientHttpHeader = h
}
}
23 changes: 17 additions & 6 deletions server_handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ import (
)

var (
ErrNotFoundHijacker = errors.New("not found Hijacker")
bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept:")
bytesHeaderExtensions = []byte("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n")
bytesCRLF = []byte("\r\n")
bytesColon = []byte(": ")
ErrNotFoundHijacker = errors.New("not found Hijacker")
bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept:")
bytesHeaderExtensions = []byte("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n")
bytesCRLF = []byte("\r\n")
strGetSecWebSocketProtocolKey = "Sec-WebSocket-Protocol"
bytesPutSecWebSocketProtocolKey = []byte("Sec-WebSocket-Protocol: ")
)

type ConnOption struct {
Expand Down Expand Up @@ -67,6 +68,17 @@ func prepareWriteResponse(r *http.Request, w io.Writer, cnf *Config) (err error)
}
}

v = r.Header.Get(strGetSecWebSocketProtocolKey)
if len(v) > 0 {
if _, err = w.Write(bytesPutSecWebSocketProtocolKey); err != nil {
return
}

if err = writeHeaderVal(w, StringToBytes(v)); err != nil {
return err
}
}

_, err = w.Write(bytesCRLF)
return err
}
Expand Down Expand Up @@ -111,7 +123,6 @@ func checkRequest(r *http.Request) (ecode int, err error) {
return http.StatusUpgradeRequired, ErrSecWebSocketVersion
}

// TODO Sec-WebSocket-Protocol
// TODO Sec-WebSocket-Extensions
return 0, nil
}
Loading