From 3c74fbda91793e54d68c52b3e6dd27f74058ec1f Mon Sep 17 00:00:00 2001 From: instabledesign Date: Fri, 27 Sep 2019 15:09:17 +0200 Subject: [PATCH] improve tripperware request id (#17) --- tripperware/request_id.go | 3 ++- tripperware/request_id_test.go | 14 +++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tripperware/request_id.go b/tripperware/request_id.go index d9c64e0..8b57264 100644 --- a/tripperware/request_id.go +++ b/tripperware/request_id.go @@ -19,9 +19,10 @@ func RequestId(config *request_id.Config) httpware.Tripperware { } if id == "" { id = config.IdGenerator(req) + // add requestId header to current request + req.Header.Add(config.HeaderName, id) } r := req.WithContext(context.WithValue(req.Context(), config.HeaderName, id)) - r.Header.Add(config.HeaderName, id) return next.RoundTrip(r) }) } diff --git a/tripperware/request_id_test.go b/tripperware/request_id_test.go index a0ae11a..6787224 100644 --- a/tripperware/request_id_test.go +++ b/tripperware/request_id_test.go @@ -2,8 +2,10 @@ package tripperware_test import ( "fmt" + "math/rand" "net/http" "net/http/httptest" + "os" "testing" "github.com/stretchr/testify/assert" @@ -15,6 +17,14 @@ import ( "github.com/gol4ng/httpware/tripperware" ) +func TestMain(m *testing.M){ + request_id.DefaultIdGenerator = request_id.NewRandomIdGenerator( + rand.New(request_id.NewLockedSource(rand.NewSource(1))), + 10, + ) + os.Exit(m.Run()) +} + func TestRequestId(t *testing.T) { roundTripperMock := &mocks.RoundTripper{} req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) @@ -27,11 +37,13 @@ func TestRequestId(t *testing.T) { roundTripperMock.On("RoundTrip", mock.AnythingOfType("*http.Request")).Return(resp, nil).Run(func(args mock.Arguments) { innerReq := args.Get(0).(*http.Request) assert.True(t, len(innerReq.Header.Get(request_id.HeaderName)) == 10) + assert.Equal(t, req.Header.Get(request_id.HeaderName), innerReq.Header.Get(request_id.HeaderName)) }) resp2, err := tripperware.RequestId(request_id.NewConfig())(roundTripperMock).RoundTrip(req) assert.Nil(t, err) assert.Equal(t, resp, resp2) + assert.Equal(t, "p1LGIehp1s", req.Header.Get(request_id.HeaderName)) } func TestRequestIdCustom(t *testing.T) { @@ -95,7 +107,7 @@ func ExampleRequestId() { } }() - _, _ = client.Get("http://localhost"+port+"/") + _, _ = client.Get("http://localhost" + port + "/") // Output: server receive request with request id: my-generated-id }