diff --git a/README.md b/README.md index 7f832ea..472f2e2 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ quickws是一个高性能的websocket库 * [配置header](#配置header) * [配置握手时的超时时间](#配置握手时的超时时间) * [配置自动回复ping消息](#配置自动回复ping消息) + * [配置socks5代理](#配置socks5代理) * [服务配置参数](#服务端配置) * [配置服务自动回复ping消息](#配置服务自动回复ping消息) ## 注意⚠️ @@ -196,6 +197,19 @@ func main() { quickws.Dial("ws://127.0.0.1:12345/test", quickws.WithClientReplyPing()) } ``` +#### 配置socks5代理 +```go +import( + "github.com/antlabs/quickws" + "golang.org/x/net/proxy" +) + +func main() { + quickws.Dial("ws://127.0.0.1:12345", quickws.WithClientDialFunc(func() (quickws.Dialer, error) { + return proxy.SOCKS5("tcp", "socks5代理服务地址", nil, nil) + })) +} +``` ### 服务端配置参数 #### 配置服务自动回复ping消息 ```go diff --git a/autobahn/autobahn-server.go b/autobahn/autobahn-server.go index 032b7d6..b775d80 100644 --- a/autobahn/autobahn-server.go +++ b/autobahn/autobahn-server.go @@ -64,7 +64,7 @@ func main() { mux := &http.ServeMux{} mux.HandleFunc("/autobahn", echo) - rawTCP, err := net.Listen("tcp", "localhost:9001") + rawTCP, err := net.Listen("tcp", ":9001") if err != nil { fmt.Println("Listen fail:", err) return diff --git a/client.go b/client.go index 9729a2a..d1a6bfc 100644 --- a/client.go +++ b/client.go @@ -183,9 +183,18 @@ func (d *DialOption) Dial() (c *Conn, err error) { return nil, err } + var conn net.Conn begin := time.Now() // conn, err := net.DialTimeout("tcp", d.u.Host /* TODO 加端号*/, d.dialTimeout) - conn, err := net.Dial("tcp", d.u.Host /* TODO 加端号*/) + if d.dialFunc == nil { + conn, err = net.Dial("tcp", d.u.Host /* TODO 加端号*/) + } else { + dialInterface, err := d.dialFunc() + if err != nil { + return nil, err + } + conn, err = dialInterface.Dial("tcp", d.u.Host) + } if err != nil { return nil, err } diff --git a/client_option_test.go b/client_option_test.go index e76cd2d..82f38e9 100644 --- a/client_option_test.go +++ b/client_option_test.go @@ -15,13 +15,19 @@ package quickws import ( + "bytes" "crypto/tls" + "encoding/binary" + "io" + "net" "net/http" "net/http/httptest" "strings" "sync/atomic" "testing" "time" + + "golang.org/x/net/proxy" ) func Test_ClientOption(t *testing.T) { @@ -156,4 +162,204 @@ func Test_ClientOption(t *testing.T) { t.Error("header fail") } }) + + t.Run("18 Dial: WithClientDialFunc.1", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, o Opcode, b []byte) { + c.WriteMessage(o, b) + c.Close() + })) + if err != nil { + t.Error(err) + } + + conn.StartReadLoop() + })) + + proxyAddr, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Error(err) + } + defer ts.Close() + + go func() { + newConn, err := proxyAddr.Accept() + if err != nil { + t.Error(err) + } + + newConn.SetDeadline(time.Now().Add(30 * time.Second)) + + buf := make([]byte, 128) + if _, err := io.ReadFull(newConn, buf[:3]); err != nil { + t.Errorf("read failed: %v", err) + return + } + + // socks version 5, 1 authentication method, no auth + if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + } + + // socks version 5, connect command, reserved, ipv4 address, port 80 + if _, err := newConn.Write([]byte{5, 0}); err != nil { + t.Errorf("write failed: %v", err) + return + } + + // ver cmd rsv atyp dst.addr dst.port + if _, err := io.ReadFull(newConn, buf[:10]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + return + } + buf[1] = 0 + if _, err := newConn.Write(buf[:10]); err != nil { + t.Errorf("write failed: %v", err) + return + } + + // 提取ip + ip := net.IP(buf[4:8]) + port := binary.BigEndian.Uint16(buf[8:10]) + + c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) + if err != nil { + t.Errorf("dial failed; %v", err) + return + } + defer c2.Close() + done := make(chan struct{}) + go func() { + io.Copy(newConn, c2) + close(done) + }() + io.Copy(c2, newConn) + <-done + }() + + got := make([]byte, 0, 128) + url := strings.ReplaceAll(ts.URL, "http", "ws") + c, err := Dial(url, WithClientDialFunc(func() (Dialer, error) { + return proxy.SOCKS5("tcp", proxyAddr.Addr().String(), nil, nil) + }), WithClientOnMessageFunc(func(c *Conn, o Opcode, b []byte) { + got = append(got, b...) + c.Close() + })) + if err != nil { + t.Error(err) + } + + data := []byte("hello world") + c.WriteMessage(Binary, data) + c.ReadLoop() + + t.Log("got", string(got), "want", string(data)) + if !bytes.Equal(got, data) { + t.Errorf("got %s, want %s", got, data) + } + }) + + t.Run("18 Dial: WithClientDialFunc.2", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, o Opcode, b []byte) { + c.WriteMessage(o, b) + c.Close() + })) + if err != nil { + t.Error(err) + } + + conn.StartReadLoop() + })) + + proxyAddr, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Error(err) + } + defer ts.Close() + + go func() { + newConn, err := proxyAddr.Accept() + if err != nil { + t.Error(err) + } + + newConn.SetDeadline(time.Now().Add(30 * time.Second)) + + buf := make([]byte, 128) + if _, err := io.ReadFull(newConn, buf[:3]); err != nil { + t.Errorf("read failed: %v", err) + return + } + + // socks version 5, 1 authentication method, no auth + if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + } + + // socks version 5, connect command, reserved, ipv4 address, port 80 + if _, err := newConn.Write([]byte{5, 0}); err != nil { + t.Errorf("write failed: %v", err) + return + } + + // ver cmd rsv atyp dst.addr dst.port + if _, err := io.ReadFull(newConn, buf[:10]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + return + } + buf[1] = 0 + if _, err := newConn.Write(buf[:10]); err != nil { + t.Errorf("write failed: %v", err) + return + } + + // 提取ip + ip := net.IP(buf[4:8]) + port := binary.BigEndian.Uint16(buf[8:10]) + + c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) + if err != nil { + t.Errorf("dial failed; %v", err) + return + } + defer c2.Close() + done := make(chan struct{}) + go func() { + io.Copy(newConn, c2) + close(done) + }() + io.Copy(c2, newConn) + <-done + }() + + got := make([]byte, 0, 128) + url := strings.ReplaceAll(ts.URL, "http", "ws") + c, err := DialConf(url, ClientOptionToConf(WithClientDialFunc(func() (Dialer, error) { + return proxy.SOCKS5("tcp", proxyAddr.Addr().String(), nil, nil) + }), WithClientOnMessageFunc(func(c *Conn, o Opcode, b []byte) { + got = append(got, b...) + c.Close() + }))) + if err != nil { + t.Error(err) + } + + data := []byte("hello world") + c.WriteMessage(Binary, data) + c.ReadLoop() + + t.Log("got", string(got), "want", string(data)) + if !bytes.Equal(got, data) { + t.Errorf("got %s, want %s", got, data) + } + }) } diff --git a/common_options.go b/common_options.go index 27eb848..a0e2933 100644 --- a/common_options.go +++ b/common_options.go @@ -300,3 +300,10 @@ func WithClientOnCloseFunc(onClose func(c *Conn, err error)) ClientOption { o.Callback = OnCloseFunc(onClose) } } + +// 18. 配置新的dial函数 +func WithClientDialFunc(dialFunc func() (Dialer, error)) ClientOption { + return func(o *DialOption) { + o.dialFunc = dialFunc + } +} diff --git a/config.go b/config.go index ccbd334..2c9b289 100644 --- a/config.go +++ b/config.go @@ -15,11 +15,16 @@ package quickws import ( + "net" "time" "github.com/antlabs/wsutil/enum" ) +type Dialer interface { + Dial(network, addr string) (c net.Conn, err error) +} + type Config struct { Callback tcpNoDelay bool @@ -37,6 +42,7 @@ type Config struct { delayWriteInitBufferSize int32 // 延迟写入的初始缓冲区大小, 默认值是8k maxDelayWriteDuration time.Duration // 最大延迟时间, 默认值是10ms subProtocols []string // 设置支持的子协议 + dialFunc func() (Dialer, error) } func (c *Config) initPayloadSize() int { diff --git a/go.mod b/go.mod index c8fbfae..ad9d6d2 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/antlabs/quickws go 1.20 -require github.com/antlabs/wsutil v0.1.2 +require ( + github.com/antlabs/wsutil v0.1.2 + golang.org/x/net v0.19.0 +) diff --git a/go.sum b/go.sum index e9aaac5..b5dc98b 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ github.com/antlabs/wsutil v0.1.2 h1:8H6E0eMJ2Wp0qi9YGDeyG3DlIfIZncw2NSScC5bYSBQ= github.com/antlabs/wsutil v0.1.2/go.mod h1:7ec5eUM7nmKW+Oi6F1I58iatOeL9k+yIsfOh1zh910g= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=