diff --git a/fetcher/fetcher.go b/fetcher/fetcher.go index a8b94b94f..122a11d77 100644 --- a/fetcher/fetcher.go +++ b/fetcher/fetcher.go @@ -85,17 +85,7 @@ func New( serverAddress string, options ...Option, ) *Fetcher { - // Create default fetcher - clientCfg := client.NewConfiguration( - serverAddress, - DefaultUserAgent, - &http.Client{ - Timeout: DefaultHTTPTimeout, - }) - client := client.NewAPIClient(clientCfg) - f := &Fetcher{ - rosettaClient: client, maxConnections: DefaultMaxConnections, maxRetries: DefaultRetries, retryElapsedTime: DefaultElapsedTime, @@ -106,18 +96,34 @@ func New( opt(f) } - // Override transport idle connection settings - // - // See this conversation around why `.Clone()` is used here: - // https://github.com/golang/go/issues/26013 - customTransport := http.DefaultTransport.(*http.Transport).Clone() - customTransport.IdleConnTimeout = DefaultIdleConnTimeout - customTransport.MaxIdleConns = f.maxConnections - customTransport.MaxIdleConnsPerHost = f.maxConnections + if f.rosettaClient == nil { + // Override transport idle connection settings + // + // See this conversation around why `.Clone()` is used here: + // https://github.com/golang/go/issues/26013 + defaultTransport := http.DefaultTransport.(*http.Transport).Clone() + defaultTransport.IdleConnTimeout = DefaultIdleConnTimeout + defaultTransport.MaxIdleConns = f.maxConnections + defaultTransport.MaxIdleConnsPerHost = DefaultMaxConnections + defaultHTTPClient := &http.Client{ + Timeout: DefaultHTTPTimeout, + Transport: defaultTransport, + } + + // Create default fetcher + clientCfg := client.NewConfiguration( + serverAddress, + DefaultUserAgent, + defaultHTTPClient, + ) + f.rosettaClient = client.NewAPIClient(clientCfg) + } + if f.insecureTLS { - customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402 + if transport, ok := f.rosettaClient.GetConfig().HTTPClient.Transport.(*http.Transport); ok { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402 + } } - f.rosettaClient.GetConfig().HTTPClient.Transport = customTransport // Initialize the connection semaphore f.connectionSemaphore = semaphore.NewWeighted(int64(f.maxConnections)) diff --git a/fetcher/fetcher_test.go b/fetcher/fetcher_test.go index 7f2e580c9..98e1e63aa 100644 --- a/fetcher/fetcher_test.go +++ b/fetcher/fetcher_test.go @@ -26,6 +26,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/coinbase/rosetta-sdk-go/asserter" + "github.com/coinbase/rosetta-sdk-go/client" "github.com/coinbase/rosetta-sdk-go/types" ) @@ -195,3 +196,21 @@ func TestInitializeAsserter(t *testing.T) { }) } } + +func TestNewWithHTTPCLient(t *testing.T) { + // Callers can pass an http.Client to + // the fetcher via WithClient. + // Ensure that the fetcher does not + // override it. + httpClient := &http.Client{} + apiClient := client.NewAPIClient( + client.NewConfiguration( + "https://serveraddress", + DefaultUserAgent, + httpClient, + ), + ) + fetcher := New("https://serveraddress", WithClient(apiClient)) + var assert = assert.New(t) + assert.Same(httpClient, fetcher.rosettaClient.GetConfig().HTTPClient) +}