From 04f4f82b8e10d599bb8d66ffbb884bde903ad803 Mon Sep 17 00:00:00 2001 From: Cyril Tovena Date: Thu, 29 Apr 2021 09:17:33 -0400 Subject: [PATCH] Adds the ability to provide a tripperware to Promtail client. (#3654) Fixes #3608 Signed-off-by: Cyril Tovena --- clients/pkg/promtail/client/client.go | 21 +++++++++++++ clients/pkg/promtail/client/client_test.go | 34 ++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/clients/pkg/promtail/client/client.go b/clients/pkg/promtail/client/client.go index 451957a931ce..3b2c6f5381e1 100644 --- a/clients/pkg/promtail/client/client.go +++ b/clients/pkg/promtail/client/client.go @@ -157,8 +157,15 @@ type client struct { cancel context.CancelFunc } +// Tripperware can wrap a roundtripper. +type Tripperware func(http.RoundTripper) http.RoundTripper + // New makes a new Client. func New(reg prometheus.Registerer, cfg Config, logger log.Logger) (Client, error) { + return newClient(reg, cfg, logger) +} + +func newClient(reg prometheus.Registerer, cfg Config, logger log.Logger) (*client, error) { if cfg.URL.URL == nil { return nil, errors.New("client needs target URL") } @@ -199,6 +206,20 @@ func New(reg prometheus.Registerer, cfg Config, logger log.Logger) (Client, erro return c, nil } +// NewWithTripperware creates a new Loki client with a custom tripperware. +func NewWithTripperware(reg prometheus.Registerer, cfg Config, logger log.Logger, tp Tripperware) (Client, error) { + c, err := newClient(reg, cfg, logger) + if err != nil { + return nil, err + } + + if tp != nil { + c.client.Transport = tp(c.client.Transport) + } + + return c, nil +} + func (c *client) run() { batches := map[string]*batch{} diff --git a/clients/pkg/promtail/client/client_test.go b/clients/pkg/promtail/client/client_test.go index 7eb35edb14c5..2b236417270f 100644 --- a/clients/pkg/promtail/client/client_test.go +++ b/clients/pkg/promtail/client/client_test.go @@ -1,9 +1,11 @@ package client import ( + "io" "math" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -459,3 +461,35 @@ func createServerHandler(receivedReqsChan chan receivedReq, status int) http.Han rw.WriteHeader(status) }) } + +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (r RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return r(req) +} + +func Test_Tripperware(t *testing.T) { + url, err := url.Parse("http://foo.com") + require.NoError(t, err) + var called bool + c, err := NewWithTripperware(nil, Config{ + URL: flagext.URLValue{URL: url}, + }, log.NewNopLogger(), func(rt http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + require.Equal(t, r.URL.String(), "http://foo.com") + called = true + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil + }) + }) + require.NoError(t, err) + + c.Chan() <- api.Entry{ + Labels: model.LabelSet{"foo": "bar"}, + Entry: logproto.Entry{Timestamp: time.Now(), Line: "foo"}, + } + c.Stop() + require.True(t, called) +}