diff --git a/kit/stub/option.go b/kit/stub/option.go index e29508c0..7f4a5959 100644 --- a/kit/stub/option.go +++ b/kit/stub/option.go @@ -1,9 +1,7 @@ package stub import ( - "fmt" "io" - "strings" "time" "github.com/clubpay/ronykit/kit" @@ -25,7 +23,7 @@ type config struct { tp kit.TracePropagator readTimeout, writeTimeout, dialTimeout time.Duration - httpProxyConfig *httpproxy.Config + proxy *httpproxy.Config dialFunc fasthttp.DialFunc } @@ -87,29 +85,25 @@ func WithTracePropagator(tp kit.TracePropagator) Option { // WithHTTPProxy returns an Option that sets the dialer to the provided HTTP proxy. // example formats: // -// http://localhost:9050 -// http://username:password@localhost:9050 -// https://localhost:9050 -func WithHTTPProxy(url string, timeout time.Duration) Option { +// localhost:9050 +// username:password@localhost:9050 +// localhost:9050 +func WithHTTPProxy(proxyURL string, timeout time.Duration) Option { return func(cfg *config) { - cfg.httpProxyConfig = httpproxy.FromEnvironment() - switch { - default: - panic(fmt.Errorf("unsupported proxy scheme: %s", url)) - case strings.HasPrefix(url, "https://"): - cfg.httpProxyConfig.HTTPSProxy = url - case strings.HasPrefix(url, "http://"): - cfg.httpProxyConfig.HTTPProxy = url - } - - cfg.dialFunc = fasthttpproxy.FasthttpHTTPDialerTimeout(url, timeout) + cfg.proxy = httpproxy.FromEnvironment() + cfg.proxy.HTTPProxy = proxyURL + cfg.proxy.HTTPSProxy = proxyURL + cfg.dialFunc = fasthttpproxy.FasthttpHTTPDialerTimeout(proxyURL, timeout) } } // WithSocksProxy returns an Option that sets the dialer to the provided SOCKS5 proxy. -// example format: socks5://localhost:9050 -func WithSocksProxy(url string) Option { +// example format: localhost:9050 +func WithSocksProxy(proxyURL string) Option { return func(cfg *config) { - cfg.dialFunc = fasthttpproxy.FasthttpSocksDialer(url) + cfg.proxy = httpproxy.FromEnvironment() + cfg.proxy.HTTPProxy = proxyURL + cfg.proxy.HTTPSProxy = proxyURL + cfg.dialFunc = fasthttpproxy.FasthttpSocksDialer(proxyURL) } } diff --git a/kit/stub/stub.go b/kit/stub/stub.go index df814179..55383665 100644 --- a/kit/stub/stub.go +++ b/kit/stub/stub.go @@ -21,7 +21,7 @@ type Stub struct { cfg config r *reflector.Reflector - httpC fasthttp.Client + httpC *fasthttp.Client } func New(hostPort string, opts ...Option) *Stub { @@ -37,19 +37,24 @@ func New(hostPort string, opts ...Option) *Stub { opt(&cfg) } - return &Stub{ - cfg: cfg, - r: reflector.New(), - httpC: fasthttp.Client{ - Name: cfg.name, - ReadTimeout: cfg.readTimeout, - WriteTimeout: cfg.writeTimeout, - Dial: cfg.dialFunc, - TLSConfig: &tls.Config{ - InsecureSkipVerify: cfg.skipVerifyTLS, //nolint:gosec - }, + httpC := &fasthttp.Client{ + Name: cfg.name, + ReadTimeout: cfg.readTimeout, + WriteTimeout: cfg.writeTimeout, + TLSConfig: &tls.Config{ + InsecureSkipVerify: cfg.skipVerifyTLS, //nolint:gosec }, } + + if cfg.dialFunc != nil { + httpC.Dial = cfg.dialFunc + } + + return &Stub{ + cfg: cfg, + r: reflector.New(), + httpC: httpC, + } } func HTTP(rawURL string, opts ...Option) (*RESTCtx, error) { @@ -79,7 +84,7 @@ func HTTP(rawURL string, opts ...Option) (*RESTCtx, error) { func (s *Stub) REST(opt ...RESTOption) *RESTCtx { ctx := &RESTCtx{ - c: &s.httpC, + c: s.httpC, r: s.r, handlers: map[int]RESTResponseHandler{}, uri: fasthttp.AcquireURI(), @@ -108,28 +113,29 @@ func (s *Stub) REST(opt ...RESTOption) *RESTCtx { } func (s *Stub) Websocket(opts ...WebsocketOption) *WebsocketCtx { - var proxyFunc func(req *http.Request) (*url.URL, error) - if s.cfg.httpProxyConfig != nil { - fn := s.cfg.httpProxyConfig.ProxyFunc() - proxyFunc = func(req *http.Request) (*url.URL, error) { - return fn(req.URL) + defaultProxy := http.ProxyFromEnvironment + if s.cfg.proxy != nil { + defaultProxy = func(req *http.Request) (*url.URL, error) { + return s.cfg.proxy.ProxyFunc()(req.URL) + } + } + + defaultDialerBuilder := func() *websocket.Dialer { + return &websocket.Dialer{ + Proxy: defaultProxy, + HandshakeTimeout: s.cfg.dialTimeout, } } ctx := &WebsocketCtx{ cfg: wsConfig{ - autoReconnect: true, - pingTime: time.Second * 30, - dialTimeout: s.cfg.dialTimeout, - writeTimeout: s.cfg.writeTimeout, - ratelimitChan: make(chan struct{}, defaultConcurrency), - rpcInFactory: common.SimpleIncomingJSONRPC, - rpcOutFactory: common.SimpleOutgoingJSONRPC, - dialerBuilder: func() *websocket.Dialer { - return &websocket.Dialer{ - Proxy: proxyFunc, - HandshakeTimeout: s.cfg.dialTimeout, - } - }, + autoReconnect: true, + pingTime: time.Second * 30, + dialTimeout: s.cfg.dialTimeout, + writeTimeout: s.cfg.writeTimeout, + ratelimitChan: make(chan struct{}, defaultConcurrency), + rpcInFactory: common.SimpleIncomingJSONRPC, + rpcOutFactory: common.SimpleOutgoingJSONRPC, + dialerBuilder: defaultDialerBuilder, tracePropagator: s.cfg.tp, }, r: s.r,