Skip to content

Commit

Permalink
Merge pull request #5275 from hashicorp/f-api-config-httpclient
Browse files Browse the repository at this point in the history
api: allow configuring http client
  • Loading branch information
Mahmood Ali committed May 20, 2019
2 parents 44f0654 + 10ab705 commit 72f46f0
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 61 deletions.
158 changes: 97 additions & 61 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,11 @@ type Config struct {
// Namespace to use. If not provided the default namespace is used.
Namespace string

// httpClient is the client to use. Default will be used if not provided.
httpClient *http.Client
// HttpClient is the client to use. Default will be used if not provided.
//
// If set, it expected to be configured for tls already, and TLSConfig is ignored.
// You may use ConfigureTLS() function to aid with initialization.
HttpClient *http.Client

// HttpAuth is the auth info to use for http access.
HttpAuth *HttpBasicAuth
Expand All @@ -132,7 +135,9 @@ type Config struct {
WaitTime time.Duration

// TLSConfig provides the various TLS related configurations for the http
// client
// client.
//
// TLSConfig is ignored if HttpClient is set.
TLSConfig *TLSConfig
}

Expand All @@ -143,12 +148,11 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
if tlsEnabled {
scheme = "https"
}
defaultConfig := DefaultConfig()
config := &Config{
Address: fmt.Sprintf("%s://%s", scheme, address),
Region: region,
Namespace: c.Namespace,
httpClient: defaultConfig.httpClient,
HttpClient: c.HttpClient,
SecretID: c.SecretID,
HttpAuth: c.HttpAuth,
WaitTime: c.WaitTime,
Expand Down Expand Up @@ -198,19 +202,23 @@ func (t *TLSConfig) Copy() *TLSConfig {
return nt
}

// DefaultConfig returns a default configuration for the client
func DefaultConfig() *Config {
config := &Config{
Address: "http://127.0.0.1:4646",
httpClient: cleanhttp.DefaultClient(),
TLSConfig: &TLSConfig{},
}
transport := config.httpClient.Transport.(*http.Transport)
func defaultHttpClient() *http.Client {
httpClient := cleanhttp.DefaultClient()
transport := httpClient.Transport.(*http.Transport)
transport.TLSHandshakeTimeout = 10 * time.Second
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}

return httpClient
}

// DefaultConfig returns a default configuration for the client
func DefaultConfig() *Config {
config := &Config{
Address: "http://127.0.0.1:4646",
TLSConfig: &TLSConfig{},
}
if addr := os.Getenv("NOMAD_ADDR"); addr != "" {
config.Address = addr
}
Expand Down Expand Up @@ -260,49 +268,72 @@ func DefaultConfig() *Config {
return config
}

// SetTimeout is used to place a timeout for connecting to Nomad. A negative
// duration is ignored, a duration of zero means no timeout, and any other value
// will add a timeout.
func (c *Config) SetTimeout(t time.Duration) error {
if c == nil {
return fmt.Errorf("nil config")
} else if c.httpClient == nil {
return fmt.Errorf("nil HTTP client")
} else if c.httpClient.Transport == nil {
return fmt.Errorf("nil HTTP client transport")
// cloneWithTimeout returns a cloned httpClient with set timeout if positive;
// otherwise, returns the same client
func cloneWithTimeout(httpClient *http.Client, t time.Duration) (*http.Client, error) {
if httpClient == nil {
return nil, fmt.Errorf("nil HTTP client")
} else if httpClient.Transport == nil {
return nil, fmt.Errorf("nil HTTP client transport")
}

// Apply a timeout.
if t.Nanoseconds() >= 0 {
transport, ok := c.httpClient.Transport.(*http.Transport)
if !ok {
return fmt.Errorf("unexpected HTTP transport: %T", c.httpClient.Transport)
}

transport.DialContext = (&net.Dialer{
Timeout: t,
KeepAlive: 30 * time.Second,
}).DialContext
if t.Nanoseconds() < 0 {
return httpClient, nil
}

return nil
tr, ok := httpClient.Transport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("unexpected HTTP transport: %T", httpClient.Transport)
}

// copy all public fields, to avoid copying transient state and locks
ntr := &http.Transport{
Proxy: tr.Proxy,
DialContext: tr.DialContext,
Dial: tr.Dial,
DialTLS: tr.DialTLS,
TLSClientConfig: tr.TLSClientConfig,
TLSHandshakeTimeout: tr.TLSHandshakeTimeout,
DisableKeepAlives: tr.DisableKeepAlives,
DisableCompression: tr.DisableCompression,
MaxIdleConns: tr.MaxIdleConns,
MaxIdleConnsPerHost: tr.MaxIdleConnsPerHost,
MaxConnsPerHost: tr.MaxConnsPerHost,
IdleConnTimeout: tr.IdleConnTimeout,
ResponseHeaderTimeout: tr.ResponseHeaderTimeout,
ExpectContinueTimeout: tr.ExpectContinueTimeout,
TLSNextProto: tr.TLSNextProto,
ProxyConnectHeader: tr.ProxyConnectHeader,
MaxResponseHeaderBytes: tr.MaxResponseHeaderBytes,
}

// apply timeout
ntr.DialContext = (&net.Dialer{
Timeout: t,
KeepAlive: 30 * time.Second,
}).DialContext

// clone http client with new transport
nc := *httpClient
nc.Transport = ntr
return &nc, nil
}

// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
func (c *Config) ConfigureTLS() error {
if c.TLSConfig == nil {
func ConfigureTLS(httpClient *http.Client, tlsConfig *TLSConfig) error {
if tlsConfig == nil {
return nil
}
if c.httpClient == nil {
if httpClient == nil {
return fmt.Errorf("config HTTP Client must be set")
}

var clientCert tls.Certificate
foundClientCert := false
if c.TLSConfig.ClientCert != "" || c.TLSConfig.ClientKey != "" {
if c.TLSConfig.ClientCert != "" && c.TLSConfig.ClientKey != "" {
if tlsConfig.ClientCert != "" || tlsConfig.ClientKey != "" {
if tlsConfig.ClientCert != "" && tlsConfig.ClientKey != "" {
var err error
clientCert, err = tls.LoadX509KeyPair(c.TLSConfig.ClientCert, c.TLSConfig.ClientKey)
clientCert, err = tls.LoadX509KeyPair(tlsConfig.ClientCert, tlsConfig.ClientKey)
if err != nil {
return err
}
Expand All @@ -312,30 +343,31 @@ func (c *Config) ConfigureTLS() error {
}
}

clientTLSConfig := c.httpClient.Transport.(*http.Transport).TLSClientConfig
clientTLSConfig := httpClient.Transport.(*http.Transport).TLSClientConfig
rootConfig := &rootcerts.Config{
CAFile: c.TLSConfig.CACert,
CAPath: c.TLSConfig.CAPath,
CAFile: tlsConfig.CACert,
CAPath: tlsConfig.CAPath,
}
if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil {
return err
}

clientTLSConfig.InsecureSkipVerify = c.TLSConfig.Insecure
clientTLSConfig.InsecureSkipVerify = tlsConfig.Insecure

if foundClientCert {
clientTLSConfig.Certificates = []tls.Certificate{clientCert}
}
if c.TLSConfig.TLSServerName != "" {
clientTLSConfig.ServerName = c.TLSConfig.TLSServerName
if tlsConfig.TLSServerName != "" {
clientTLSConfig.ServerName = tlsConfig.TLSServerName
}

return nil
}

// Client provides a client to the Nomad API
type Client struct {
config Config
httpClient *http.Client
config Config
}

// NewClient returns a new client
Expand All @@ -349,17 +381,17 @@ func NewClient(config *Config) (*Client, error) {
return nil, fmt.Errorf("invalid address '%s': %v", config.Address, err)
}

if config.httpClient == nil {
config.httpClient = defConfig.httpClient
}

// Configure the TLS configurations
if err := config.ConfigureTLS(); err != nil {
return nil, err
httpClient := config.HttpClient
if httpClient == nil {
httpClient = defaultHttpClient()
if err := ConfigureTLS(httpClient, config.TLSConfig); err != nil {
return nil, err
}
}

client := &Client{
config: *config,
config: *config,
httpClient: httpClient,
}
return client, nil
}
Expand Down Expand Up @@ -428,8 +460,12 @@ func (c *Client) getNodeClientImpl(nodeID string, timeout time.Duration, q *Quer
// Get an API client for the node
conf := c.config.ClientConfig(region, node.HTTPAddr, node.TLSEnabled)

// Set the timeout
conf.SetTimeout(timeout)
// set timeout - preserve old behavior where errors are ignored and use untimed one
httpClient, err := cloneWithTimeout(c.httpClient, timeout)
if err == nil {
httpClient = c.httpClient
}
conf.HttpClient = httpClient

return NewClient(conf)
}
Expand Down Expand Up @@ -612,7 +648,7 @@ func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) {
return 0, nil, err
}
start := time.Now()
resp, err := c.config.httpClient.Do(req)
resp, err := c.httpClient.Do(req)
diff := time.Now().Sub(start)

// If the response is compressed, we swap the body's reader.
Expand Down Expand Up @@ -659,14 +695,14 @@ func (c *Client) rawQuery(endpoint string, q *QueryOptions) (io.ReadCloser, erro
// websocket makes a websocket request to the specific endpoint
func (c *Client) websocket(endpoint string, q *QueryOptions) (*websocket.Conn, *http.Response, error) {

transport, ok := c.config.httpClient.Transport.(*http.Transport)
transport, ok := c.httpClient.Transport.(*http.Transport)
if !ok {
return nil, nil, fmt.Errorf("unsupported transport")
}
dialer := websocket.Dialer{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
HandshakeTimeout: c.config.httpClient.Timeout,
HandshakeTimeout: c.httpClient.Timeout,

// values to inherit from http client configuration
NetDial: transport.Dial,
Expand Down
39 changes: 39 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
Expand All @@ -13,6 +14,7 @@ import (
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/nomad/api/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type configCallback func(c *Config)
Expand Down Expand Up @@ -443,3 +445,40 @@ func TestClient_NodeClient(t *testing.T) {
})
}
}

func TestCloneHttpClient(t *testing.T) {
client := defaultHttpClient()
originalTransport := client.Transport.(*http.Transport)
originalTransport.Proxy = func(*http.Request) (*url.URL, error) {
return nil, fmt.Errorf("stub function")
}

t.Run("closing with negative timeout", func(t *testing.T) {
clone, err := cloneWithTimeout(client, -1)
require.True(t, originalTransport == client.Transport, "original transport changed")
require.NoError(t, err)
require.Equal(t, client, clone)
require.True(t, client == clone)
})

t.Run("closing with positive timeout", func(t *testing.T) {
clone, err := cloneWithTimeout(client, 1*time.Second)
require.True(t, originalTransport == client.Transport, "original transport changed")
require.NoError(t, err)
require.NotEqual(t, client, clone)
require.True(t, client != clone)
require.True(t, client.Transport != clone.Transport)

// test that proxy function is the same in clone
clonedProxy := clone.Transport.(*http.Transport).Proxy
require.NotNil(t, clonedProxy)
_, err = clonedProxy(nil)
require.Error(t, err)
require.Equal(t, "stub function", err.Error())

// if we reset transport, the strutcs are equal
clone.Transport = originalTransport
require.Equal(t, client, clone)
})

}

0 comments on commit 72f46f0

Please sign in to comment.