Skip to content

Commit

Permalink
#226 added new method NewWithLocalAddr to dial from local address and…
Browse files Browse the repository at this point in the history
… test cases update
  • Loading branch information
jeevatkm committed Mar 12, 2019
1 parent 551e301 commit 97a5dbf
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 45 deletions.
64 changes: 36 additions & 28 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,10 @@ func (c *Client) AddRetryCondition(condition RetryConditionFunc) *Client {
//
// // or One can disable security check (https)
// client.SetTLSClientConfig(&tls.Config{ InsecureSkipVerify: true })
//
// Note: This method overwrites existing `TLSClientConfig`.
func (c *Client) SetTLSClientConfig(config *tls.Config) *Client {
transport, err := c.getTransport()
transport, err := c.transport()
if err != nil {
c.Log.Printf("ERROR %v", err)
return c
Expand All @@ -552,7 +553,7 @@ func (c *Client) SetTLSClientConfig(config *tls.Config) *Client {
//
// Refer to godoc `http.ProxyFromEnvironment`.
func (c *Client) SetProxy(proxyURL string) *Client {
transport, err := c.getTransport()
transport, err := c.transport()
if err != nil {
c.Log.Printf("ERROR %v", err)
return c
Expand All @@ -572,7 +573,7 @@ func (c *Client) SetProxy(proxyURL string) *Client {
// RemoveProxy method removes the proxy configuration from Resty client
// client.RemoveProxy()
func (c *Client) RemoveProxy() *Client {
transport, err := c.getTransport()
transport, err := c.transport()
if err != nil {
c.Log.Printf("ERROR %v", err)
return c
Expand All @@ -584,7 +585,7 @@ func (c *Client) RemoveProxy() *Client {

// SetCertificates method helps to set client certificates into Resty conveniently.
func (c *Client) SetCertificates(certs ...tls.Certificate) *Client {
config, err := c.getTLSConfig()
config, err := c.tlsConfig()
if err != nil {
c.Log.Printf("ERROR %v", err)
return c
Expand All @@ -602,7 +603,7 @@ func (c *Client) SetRootCertificate(pemFilePath string) *Client {
return c
}

config, err := c.getTLSConfig()
config, err := c.tlsConfig()
if err != nil {
c.Log.Printf("ERROR %v", err)
return c
Expand Down Expand Up @@ -838,8 +839,8 @@ func (c *Client) execute(req *Request) (*Response, error) {
}

// getting TLS client config if not exists then create one
func (c *Client) getTLSConfig() (*tls.Config, error) {
transport, err := c.getTransport()
func (c *Client) tlsConfig() (*tls.Config, error) {
transport, err := c.transport()
if err != nil {
return nil, err
}
Expand All @@ -849,26 +850,9 @@ func (c *Client) getTLSConfig() (*tls.Config, error) {
return transport.TLSClientConfig, nil
}

// returns `*http.Transport` currently in use or error
// in case currently used `transport` is not an `*http.Transport`
func (c *Client) getTransport() (*http.Transport, error) {
if c.httpClient.Transport == nil {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1,
}
c.SetTransport(transport)
}

// Transport method returns `*http.Transport` currently in use or error
// in case currently used `transport` is not a `*http.Transport`.
func (c *Client) transport() (*http.Transport, error) {
if transport, ok := c.httpClient.Transport.(*http.Transport); ok {
return transport, nil
}
Expand Down Expand Up @@ -908,7 +892,11 @@ type MultipartField struct {
//_______________________________________________________________________

func createClient(hc *http.Client) *Client {
c := &Client{ // not setting default values
if hc.Transport == nil {
hc.Transport = createTransport(nil)
}

c := &Client{ // not setting lang default values
QueryParam: url.Values{},
FormData: url.Values{},
Header: http.Header{},
Expand Down Expand Up @@ -948,3 +936,23 @@ func createClient(hc *http.Client) *Client {

return c
}

func createTransport(localAddr net.Addr) *http.Transport {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}
if localAddr != nil {
dialer.LocalAddr = localAddr
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1,
}
}
26 changes: 20 additions & 6 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/tls"
"errors"
"io/ioutil"
"net"
"net/http"
"net/url"
"path/filepath"
Expand Down Expand Up @@ -148,7 +149,7 @@ func TestClientSetCertificates(t *testing.T) {
client := dc()
client.SetCertificates(tls.Certificate{})

transport, err := client.getTransport()
transport, err := client.transport()

assertNil(t, err)
assertEqual(t, 1, len(transport.TLSClientConfig.Certificates))
Expand All @@ -158,7 +159,7 @@ func TestClientSetRootCertificate(t *testing.T) {
client := dc()
client.SetRootCertificate(filepath.Join(getTestDataPath(), "sample-root.pem"))

transport, err := client.getTransport()
transport, err := client.transport()

assertNil(t, err)
assertNotNil(t, transport.TLSClientConfig.RootCAs)
Expand All @@ -168,7 +169,7 @@ func TestClientSetRootCertificateNotExists(t *testing.T) {
client := dc()
client.SetRootCertificate(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem"))

transport, err := client.getTransport()
transport, err := client.transport()

assertNil(t, err)
assertNil(t, transport.TLSClientConfig)
Expand Down Expand Up @@ -207,7 +208,7 @@ func TestClientSetTransport(t *testing.T) {
},
}
client.SetTransport(transport)
transportInUse, err := client.getTransport()
transportInUse, err := client.transport()

assertNil(t, err)
assertEqual(t, true, transport == transportInUse)
Expand Down Expand Up @@ -308,7 +309,7 @@ func TestClientOptions(t *testing.T) {
}

client.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
transport, transportErr := client.getTransport()
transport, transportErr := client.transport()

assertNil(t, transportErr)
assertEqual(t, true, transport.TLSClientConfig.InsecureSkipVerify)
Expand Down Expand Up @@ -383,7 +384,7 @@ func TestClientRoundTripper(t *testing.T) {
rt := &CustomRoundTripper{}
c.SetTransport(rt)

ct, err := c.getTransport()
ct, err := c.transport()
assertNotNil(t, err)
assertNil(t, ct)
assertEqual(t, "current transport is not an *http.Transport instance", err.Error())
Expand Down Expand Up @@ -523,3 +524,16 @@ func TestLogCallbacks(t *testing.T) {
assertEqual(t, errors.New("response test error"), err)
assertNotNil(t, resp)
}

func TestNewWithLocalAddr(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

localAddress, _ := net.ResolveTCPAddr("tcp", "127.0.0.1")
client := NewWithLocalAddr(localAddress)
client.SetHostURL(ts.URL)

resp, err := client.R().Get("/")
assertNil(t, err)
assertEqual(t, resp.String(), "TestGet: text response")
}
19 changes: 10 additions & 9 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ func TestRawFileUploadByBody(t *testing.T) {
func TestProxySetting(t *testing.T) {
c := dc()

transport, err := c.getTransport()
transport, err := c.transport()

assertNil(t, err)

Expand Down Expand Up @@ -1440,6 +1440,7 @@ func TestPathParamURLInput(t *testing.T) {
logResponse(t, resp)
}

// This test case is kind of pass always
func TestTraceInfo(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()
Expand All @@ -1453,11 +1454,11 @@ func TestTraceInfo(t *testing.T) {

tr := resp.Request.TraceInfo()
assertEqual(t, true, tr.DNSLookup >= 0)
assertEqual(t, true, tr.ConnTime > 0)
assertEqual(t, true, tr.ConnTime >= 0)
assertEqual(t, true, tr.TLSHandshake >= 0)
assertEqual(t, true, tr.ServerTime > 0)
assertEqual(t, true, tr.ResponseTime > 0)
assertEqual(t, true, tr.TotalTime > 0)
assertEqual(t, true, tr.ServerTime >= 0)
assertEqual(t, true, tr.ResponseTime >= 0)
assertEqual(t, true, tr.TotalTime >= 0)
}

client.DisableTrace()
Expand All @@ -1469,11 +1470,11 @@ func TestTraceInfo(t *testing.T) {

tr := resp.Request.TraceInfo()
assertEqual(t, true, tr.DNSLookup >= 0)
assertEqual(t, true, tr.ConnTime > 0)
assertEqual(t, true, tr.ConnTime >= 0)
assertEqual(t, true, tr.TLSHandshake >= 0)
assertEqual(t, true, tr.ServerTime > 0)
assertEqual(t, true, tr.ResponseTime > 0)
assertEqual(t, true, tr.TotalTime > 0)
assertEqual(t, true, tr.ServerTime >= 0)
assertEqual(t, true, tr.ResponseTime >= 0)
assertEqual(t, true, tr.TotalTime >= 0)
}

// for sake of hook funcs
Expand Down
17 changes: 15 additions & 2 deletions resty.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package resty

import (
"net"
"net/http"
"net/http/cookiejar"

Expand All @@ -18,10 +19,22 @@ const Version = "2.0.0-rc.1"
// New method creates a new Resty client.
func New() *Client {
cookieJar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
return createClient(&http.Client{Jar: cookieJar})
return createClient(&http.Client{
Jar: cookieJar,
})
}

// NewWithClient method create a new Resty client with given `http.Client`.
// NewWithClient method creates a new Resty client with given `http.Client`.
func NewWithClient(hc *http.Client) *Client {
return createClient(hc)
}

// NewWithLocalAddr method creates a new Resty client with given Local Address
// to dial from.
func NewWithLocalAddr(localAddr net.Addr) *Client {
cookieJar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
return createClient(&http.Client{
Jar: cookieJar,
Transport: createTransport(localAddr),
})
}

0 comments on commit 97a5dbf

Please sign in to comment.