Skip to content

Commit

Permalink
api: allow configuring http client
Browse files Browse the repository at this point in the history
Allow clients to configure httpClient, e.g. set a pooled/keep-alive
client.

When caller configures HttpClient explicitly, we aim to use as-is; e.g.
we assume it's configured with TLS already.  Expose `ConfigureTLS` to
aid api consumers with configuring their http client.

Also, removes `SetTimeout` call that I believe is internal only and has
odd side-effects when called on already created config.  Also deprecates
`config.ConfigureTLS` in preference to the new `ConfigureTLS`.
  • Loading branch information
Mahmood Ali committed May 17, 2019
1 parent 9b6e5c1 commit f278760
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 59 deletions.
166 changes: 107 additions & 59 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,11 @@ type Config struct {
// Namespace to use. If not provided the default namespace is used.
Namespace string

// httpClient is the client to use. Default will be used if not provided.
httpClient *http.Client
// HttpClient is the client to use. Default will be used if not provided.
//
// If set, it expected to be configured for tls already, and TLSConfig is ignored.
// You may use ConfigureTLS() function to aid with initialization.
HttpClient *http.Client

// HttpAuth is the auth info to use for http access.
HttpAuth *HttpBasicAuth
Expand All @@ -132,7 +135,9 @@ type Config struct {
WaitTime time.Duration

// TLSConfig provides the various TLS related configurations for the http
// client
// client.
//
// TLSConfig is ignored if HttpClient is set.
TLSConfig *TLSConfig
}

Expand All @@ -143,12 +148,11 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
if tlsEnabled {
scheme = "https"
}
defaultConfig := DefaultConfig()
config := &Config{
Address: fmt.Sprintf("%s://%s", scheme, address),
Region: region,
Namespace: c.Namespace,
httpClient: defaultConfig.httpClient,
HttpClient: c.HttpClient,
SecretID: c.SecretID,
HttpAuth: c.HttpAuth,
WaitTime: c.WaitTime,
Expand Down Expand Up @@ -198,19 +202,23 @@ func (t *TLSConfig) Copy() *TLSConfig {
return nt
}

// DefaultConfig returns a default configuration for the client
func DefaultConfig() *Config {
config := &Config{
Address: "http://127.0.0.1:4646",
httpClient: cleanhttp.DefaultClient(),
TLSConfig: &TLSConfig{},
}
transport := config.httpClient.Transport.(*http.Transport)
func defaultHttpClient() *http.Client {
httpClient := cleanhttp.DefaultClient()
transport := httpClient.Transport.(*http.Transport)
transport.TLSHandshakeTimeout = 10 * time.Second
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}

return httpClient
}

// DefaultConfig returns a default configuration for the client
func DefaultConfig() *Config {
config := &Config{
Address: "http://127.0.0.1:4646",
TLSConfig: &TLSConfig{},
}
if addr := os.Getenv("NOMAD_ADDR"); addr != "" {
config.Address = addr
}
Expand Down Expand Up @@ -260,49 +268,84 @@ func DefaultConfig() *Config {
return config
}

// SetTimeout is used to place a timeout for connecting to Nomad. A negative
// duration is ignored, a duration of zero means no timeout, and any other value
// will add a timeout.
func (c *Config) SetTimeout(t time.Duration) error {
if c == nil {
return fmt.Errorf("nil config")
} else if c.httpClient == nil {
return fmt.Errorf("nil HTTP client")
} else if c.httpClient.Transport == nil {
return fmt.Errorf("nil HTTP client transport")
// cloneWithTimeout returns a cloned httpClient with set timeout if positive;
// otherwise, returns the same client
func cloneWithTimeout(httpClient *http.Client, t time.Duration) (*http.Client, error) {
if httpClient == nil {
return nil, fmt.Errorf("nil HTTP client")
} else if httpClient.Transport == nil {
return nil, fmt.Errorf("nil HTTP client transport")
}

// Apply a timeout.
if t.Nanoseconds() >= 0 {
transport, ok := c.httpClient.Transport.(*http.Transport)
if !ok {
return fmt.Errorf("unexpected HTTP transport: %T", c.httpClient.Transport)
}
if t.Nanoseconds() < 0 {
return httpClient, nil
}

transport.DialContext = (&net.Dialer{
Timeout: t,
KeepAlive: 30 * time.Second,
}).DialContext
transport, ok := httpClient.Transport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("unexpected HTTP transport: %T", httpClient.Transport)
}

return nil
// Apply a timeout.
nc := *httpClient

// copy all public fields
ntr := &http.Transport{
Proxy: transport.Proxy,
DialContext: transport.DialContext,
Dial: transport.Dial,
DialTLS: transport.DialTLS,
TLSClientConfig: transport.TLSClientConfig,
TLSHandshakeTimeout: transport.TLSHandshakeTimeout,
DisableKeepAlives: transport.DisableKeepAlives,
DisableCompression: transport.DisableCompression,
MaxIdleConns: transport.MaxIdleConns,
MaxIdleConnsPerHost: transport.MaxIdleConnsPerHost,
MaxConnsPerHost: transport.MaxConnsPerHost,
IdleConnTimeout: transport.IdleConnTimeout,
ResponseHeaderTimeout: transport.ResponseHeaderTimeout,
ExpectContinueTimeout: transport.ExpectContinueTimeout,
TLSNextProto: transport.TLSNextProto,
ProxyConnectHeader: transport.ProxyConnectHeader,
MaxResponseHeaderBytes: transport.MaxResponseHeaderBytes,
}

ntr.DialContext = (&net.Dialer{
Timeout: t,
KeepAlive: 30 * time.Second,
}).DialContext

nc.Transport = ntr
return &nc, nil
}

// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
//
// Deprecated: This method is called internally. Consider using ConfigureTLS instead.
func (c *Config) ConfigureTLS() error {
if c.TLSConfig == nil {

// preserve backward behavior where ConfigureTLS pre0.9 always had a client
if c.HttpClient == nil {
c.HttpClient = defaultHttpClient()
}
return ConfigureTLS(c.HttpClient, c.TLSConfig)
}

// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
func ConfigureTLS(httpClient *http.Client, tlsConfig *TLSConfig) error {
if tlsConfig == nil {
return nil
}
if c.httpClient == nil {
if httpClient == nil {
return fmt.Errorf("config HTTP Client must be set")
}

var clientCert tls.Certificate
foundClientCert := false
if c.TLSConfig.ClientCert != "" || c.TLSConfig.ClientKey != "" {
if c.TLSConfig.ClientCert != "" && c.TLSConfig.ClientKey != "" {
if tlsConfig.ClientCert != "" || tlsConfig.ClientKey != "" {
if tlsConfig.ClientCert != "" && tlsConfig.ClientKey != "" {
var err error
clientCert, err = tls.LoadX509KeyPair(c.TLSConfig.ClientCert, c.TLSConfig.ClientKey)
clientCert, err = tls.LoadX509KeyPair(tlsConfig.ClientCert, tlsConfig.ClientKey)
if err != nil {
return err
}
Expand All @@ -312,30 +355,31 @@ func (c *Config) ConfigureTLS() error {
}
}

clientTLSConfig := c.httpClient.Transport.(*http.Transport).TLSClientConfig
clientTLSConfig := httpClient.Transport.(*http.Transport).TLSClientConfig
rootConfig := &rootcerts.Config{
CAFile: c.TLSConfig.CACert,
CAPath: c.TLSConfig.CAPath,
CAFile: tlsConfig.CACert,
CAPath: tlsConfig.CAPath,
}
if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil {
return err
}

clientTLSConfig.InsecureSkipVerify = c.TLSConfig.Insecure
clientTLSConfig.InsecureSkipVerify = tlsConfig.Insecure

if foundClientCert {
clientTLSConfig.Certificates = []tls.Certificate{clientCert}
}
if c.TLSConfig.TLSServerName != "" {
clientTLSConfig.ServerName = c.TLSConfig.TLSServerName
if tlsConfig.TLSServerName != "" {
clientTLSConfig.ServerName = tlsConfig.TLSServerName
}

return nil
}

// Client provides a client to the Nomad API
type Client struct {
config Config
httpClient *http.Client
config Config
}

// NewClient returns a new client
Expand All @@ -349,17 +393,17 @@ func NewClient(config *Config) (*Client, error) {
return nil, fmt.Errorf("invalid address '%s': %v", config.Address, err)
}

if config.httpClient == nil {
config.httpClient = defConfig.httpClient
}

// Configure the TLS configurations
if err := config.ConfigureTLS(); err != nil {
return nil, err
httpClient := config.HttpClient
if httpClient == nil {
httpClient = defaultHttpClient()
if err := ConfigureTLS(httpClient, config.TLSConfig); err != nil {
return nil, err
}
}

client := &Client{
config: *config,
config: *config,
httpClient: httpClient,
}
return client, nil
}
Expand Down Expand Up @@ -428,8 +472,12 @@ func (c *Client) getNodeClientImpl(nodeID string, timeout time.Duration, q *Quer
// Get an API client for the node
conf := c.config.ClientConfig(region, node.HTTPAddr, node.TLSEnabled)

// Set the timeout
conf.SetTimeout(timeout)
// set timeout - preserve old behavior where errors are ignored and use untimed one
httpClient, err := cloneWithTimeout(c.httpClient, timeout)
if err == nil {
httpClient = c.httpClient
}
conf.HttpClient = httpClient

return NewClient(conf)
}
Expand Down Expand Up @@ -612,7 +660,7 @@ func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) {
return 0, nil, err
}
start := time.Now()
resp, err := c.config.httpClient.Do(req)
resp, err := c.httpClient.Do(req)
diff := time.Now().Sub(start)

// If the response is compressed, we swap the body's reader.
Expand Down Expand Up @@ -659,14 +707,14 @@ func (c *Client) rawQuery(endpoint string, q *QueryOptions) (io.ReadCloser, erro
// websocket makes a websocket request to the specific endpoint
func (c *Client) websocket(endpoint string, q *QueryOptions) (*websocket.Conn, *http.Response, error) {

transport, ok := c.config.httpClient.Transport.(*http.Transport)
transport, ok := c.httpClient.Transport.(*http.Transport)
if !ok {
return nil, nil, fmt.Errorf("unsupported transport")
}
dialer := websocket.Dialer{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
HandshakeTimeout: c.config.httpClient.Timeout,
HandshakeTimeout: c.httpClient.Timeout,

// values to inherit from http client configuration
NetDial: transport.Dial,
Expand Down
39 changes: 39 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
Expand All @@ -13,6 +14,7 @@ import (
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/nomad/api/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type configCallback func(c *Config)
Expand Down Expand Up @@ -443,3 +445,40 @@ func TestClient_NodeClient(t *testing.T) {
})
}
}

func TestCloneHttpClient(t *testing.T) {
client := defaultHttpClient()
originalTransport := client.Transport.(*http.Transport)
originalTransport.Proxy = func(*http.Request) (*url.URL, error) {
return nil, fmt.Errorf("stub function")
}

t.Run("closing with negative timeout", func(t *testing.T) {
clone, err := cloneWithTimeout(client, -1)
require.True(t, originalTransport == client.Transport, "original transport changed")
require.NoError(t, err)
require.Equal(t, client, clone)
require.True(t, client == clone)
})

t.Run("closing with positive timeout", func(t *testing.T) {
clone, err := cloneWithTimeout(client, 1*time.Second)
require.True(t, originalTransport == client.Transport, "original transport changed")
require.NoError(t, err)
require.NotEqual(t, client, clone)
require.True(t, client != clone)
require.True(t, client.Transport != clone.Transport)

// test that proxy function is the same in clone
clonedProxy := clone.Transport.(*http.Transport).Proxy
require.NotNil(t, clonedProxy)
_, err = clonedProxy(nil)
require.Error(t, err)
require.Equal(t, "stub function", err.Error())

// if we reset transport, the strutcs are equal
clone.Transport = originalTransport
require.Equal(t, client, clone)
})

}

0 comments on commit f278760

Please sign in to comment.