diff --git a/http/client.go b/http/client.go index 2366fe2..fd843a1 100644 --- a/http/client.go +++ b/http/client.go @@ -14,9 +14,8 @@ import ( httpntlmv2 "github.com/vadimi/go-http-ntlm/v2" ) -type Middleware func(http.RoundTripper) http.RoundTripper - type TraceConfig = middlewares.TraceConfig + var TraceAll = TraceConfig{ MaxBodyLength: 4096, Body: true, @@ -36,6 +35,9 @@ var TraceHeaders = TraceConfig{ TLS: false, } +func (a *AuthConfig) IsEmpty() bool { + return a.Username == "" && a.Password == "" +} type AuthConfig struct { // Username for basic Auth @@ -79,6 +81,7 @@ type Client struct { // cacheDNS specifies whether to cache DNS lookups cacheDNS bool + userAgent string } @@ -106,6 +109,11 @@ func (c *Client) R(ctx context.Context) *Request { } } +func (c *Client) UserAgent(agent string) *Client { + c.userAgent = agent + return c +} + // Retry configuration retrying on failure with exponential backoff. // // Base duration of a second & an exponent of 2 is a good option. @@ -211,6 +219,7 @@ func (c *Client) Trace(config TraceConfig) *Client { c.Use(middlewares.NewTracedTransport(config).RoundTripper) return c } + func (c *Client) TraceToStdout(config TraceConfig) *Client { c.Use(middlewares.NewLogger(config)) return c @@ -280,7 +289,7 @@ func (c *Client) roundTrip(r *Request) (resp *Response, err error) { req.URL.RawQuery = queryParam.Encode() req.Host = host - if r.client.authConfig != nil { + if r.client.authConfig != nil && !r.client.authConfig.IsEmpty() { req.SetBasicAuth(r.client.authConfig.Username, r.client.authConfig.Password) } diff --git a/http/examples_test.go b/http/examples_test.go index f32ff84..4b16017 100644 --- a/http/examples_test.go +++ b/http/examples_test.go @@ -123,6 +123,25 @@ func TestExample(t *testing.T) { } }) + t.Run("No Auth", func(t *testing.T) { + resp, err := http.NewClient().R(context.Background()).Header("Hello", "World").Get("https://httpbin.demo.aws.flanksource.com/headers") + if err != nil { + t.Error(err) + } + var headers map[string]any + if body, err := resp.AsJSON(); err != nil { + t.Error(err) + } else { + headers = body["headers"].(map[string]any) + } + if headers["Hello"] != "World" { + t.Errorf("Expected response headers %s", headers) + } + if v, ok := headers["Authorization"]; ok { + t.Errorf("Expecting blank authentication got %s", v) + } + }) + t.Run("Tracing & logging middleware", func(t *testing.T) { client := http.NewClient().Trace(http.TraceConfig{ MaxBodyLength: 4096,