diff --git a/correlation_id/config.go b/correlation_id/config.go new file mode 100644 index 0000000..874f80d --- /dev/null +++ b/correlation_id/config.go @@ -0,0 +1,21 @@ +package correlation_id + +import ( + "net/http" +) + +const HeaderName = "Correlation-Id" + +type Config struct { + HeaderName string + IdGenerator func(*http.Request) string +} + +func NewConfig() *Config { + return &Config{ + HeaderName: HeaderName, + IdGenerator: func(_ *http.Request) string { + return DefaultIdGenerator.Generate(10) + }, + } +} diff --git a/correlation_id/config_test.go b/correlation_id/config_test.go new file mode 100644 index 0000000..cd69cd3 --- /dev/null +++ b/correlation_id/config_test.go @@ -0,0 +1,19 @@ +package correlation_id_test + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/gol4ng/httpware/correlation_id" +) + +func TestNewConfig(t *testing.T) { + defaultRand := rand.New(correlation_id.NewLockedSource(rand.NewSource(1))) + correlation_id.DefaultIdGenerator = correlation_id.NewRandomIdGenerator(defaultRand) + + for _, expectedId := range []string{"p1LGIehp1s", "uqtCDMLxiD"} { + assert.Equal(t, expectedId, correlation_id.NewConfig().IdGenerator(nil)) + } +} diff --git a/request_id/generator.go b/correlation_id/generator.go similarity index 65% rename from request_id/generator.go rename to correlation_id/generator.go index c23dfd4..2eadc65 100644 --- a/request_id/generator.go +++ b/correlation_id/generator.go @@ -1,8 +1,7 @@ -package request_id +package correlation_id import ( "math/rand" - "net/http" "time" "unsafe" ) @@ -16,14 +15,13 @@ const ( ) type RandomIdGenerator struct { - r *rand.Rand - length int + r *rand.Rand } -func (rg *RandomIdGenerator) Generate(_ *http.Request) string { - b := make([]byte, rg.length) +func (rg *RandomIdGenerator) Generate(length int) string { + b := make([]byte, length) // A src.Int63() generates 63 random bits, enough for letterIdxMax characters! - for i, cache, remain := rg.length-1, rg.r.Int63(), letterIdxMax; i >= 0; { + for i, cache, remain := length-1, rg.r.Int63(), letterIdxMax; i >= 0; { if remain == 0 { cache, remain = rg.r.Int63(), letterIdxMax } @@ -37,15 +35,18 @@ func (rg *RandomIdGenerator) Generate(_ *http.Request) string { return *(*string)(unsafe.Pointer(&b)) } -var DefaultRand = rand.New(NewLockedSource(rand.NewSource(time.Now().UTC().UnixNano()))) var DefaultIdGenerator = NewRandomIdGenerator( - DefaultRand, - 10, + rand.New( + NewLockedSource( + rand.NewSource( + time.Now().UTC().UnixNano(), + ), + ), + ), ) -func NewRandomIdGenerator(rand *rand.Rand, length int) *RandomIdGenerator { +func NewRandomIdGenerator(rand *rand.Rand) *RandomIdGenerator { return &RandomIdGenerator{ - r: rand, - length: length, + r: rand, } } diff --git a/correlation_id/generator_test.go b/correlation_id/generator_test.go new file mode 100644 index 0000000..5cdd4cd --- /dev/null +++ b/correlation_id/generator_test.go @@ -0,0 +1,22 @@ +package correlation_id_test + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/gol4ng/httpware/correlation_id" +) + +func Test_Random(t *testing.T) { + assert.Equal(t, 10, len(correlation_id.DefaultIdGenerator.Generate(10))) +} + +func Test_Random_NewSource(t *testing.T) { + r := rand.New(correlation_id.NewLockedSource(rand.NewSource(1))) + rg := correlation_id.NewRandomIdGenerator(r) + for _, expectedId := range []string{"DHIMG9FpXzp1LGIehp1s", "zAHyfjXUlrGhblT7txWd"} { + assert.Equal(t, expectedId, rg.Generate(20)) + } +} diff --git a/request_id/rand.go b/correlation_id/rand.go similarity index 96% rename from request_id/rand.go rename to correlation_id/rand.go index 2b663ff..437cf43 100644 --- a/request_id/rand.go +++ b/correlation_id/rand.go @@ -1,4 +1,4 @@ -package request_id +package correlation_id import ( "math/rand" diff --git a/request_id/rand_test.go b/correlation_id/rand_test.go similarity index 73% rename from request_id/rand_test.go rename to correlation_id/rand_test.go index fea8969..844793c 100644 --- a/request_id/rand_test.go +++ b/correlation_id/rand_test.go @@ -1,15 +1,16 @@ -package request_id_test +package correlation_id_test import ( - "github.com/stretchr/testify/assert" "math/rand" "testing" - "github.com/gol4ng/httpware/request_id" + "github.com/stretchr/testify/assert" + + "github.com/gol4ng/httpware/correlation_id" ) func Test_LockedSource_Int63(t *testing.T) { - s := request_id.NewLockedSource(rand.NewSource(1)) + s := correlation_id.NewLockedSource(rand.NewSource(1)) for _, v := range []int64{5577006791947779410, 8674665223082153551} { assert.Equal(t, v, s.Int63()) @@ -17,7 +18,7 @@ func Test_LockedSource_Int63(t *testing.T) { } func Test_LockedSource_Uint64(t *testing.T) { - s := request_id.NewLockedSource(rand.NewSource(1)) + s := correlation_id.NewLockedSource(rand.NewSource(1)) for _, v := range []uint64{0x4d65822107fcfd52, 0x78629a0f5f3f164f} { assert.Equal(t, v, s.Uint64()) @@ -25,7 +26,7 @@ func Test_LockedSource_Uint64(t *testing.T) { } func Test_LockedSource_Seed(t *testing.T) { - s := request_id.NewLockedSource(rand.NewSource(1)) + s := correlation_id.NewLockedSource(rand.NewSource(1)) for _, v := range []uint64{0x4d65822107fcfd52, 0x78629a0f5f3f164f} { assert.Equal(t, v, s.Uint64()) diff --git a/middleware/request_id.go b/middleware/correlation_id.go similarity index 78% rename from middleware/request_id.go rename to middleware/correlation_id.go index 43aeeb9..a0e6737 100644 --- a/middleware/request_id.go +++ b/middleware/correlation_id.go @@ -5,12 +5,12 @@ import ( "net/http" "github.com/gol4ng/httpware" - "github.com/gol4ng/httpware/request_id" + "github.com/gol4ng/httpware/correlation_id" ) -// RequestId middleware get request id header if provided or generate a request id +// CorrelationId middleware get request id header if provided or generate a request id // It will add the request ID to request context and add it to response header to -func RequestId(config *request_id.Config) httpware.Middleware { +func CorrelationId(config *correlation_id.Config) httpware.Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { id := req.Header.Get(config.HeaderName) diff --git a/middleware/request_id_test.go b/middleware/correlation_id_test.go similarity index 50% rename from middleware/request_id_test.go rename to middleware/correlation_id_test.go index d69404e..49f5ae1 100644 --- a/middleware/request_id_test.go +++ b/middleware/correlation_id_test.go @@ -10,16 +10,15 @@ import ( "github.com/stretchr/testify/assert" "github.com/gol4ng/httpware" + "github.com/gol4ng/httpware/correlation_id" "github.com/gol4ng/httpware/middleware" - "github.com/gol4ng/httpware/request_id" ) -func TestRequestId(t *testing.T) { - request_id.DefaultRand = rand.New(request_id.NewLockedSource(rand.NewSource(1))) - request_id.DefaultIdGenerator = request_id.NewRandomIdGenerator( - request_id.DefaultRand, - 10, +func TestCorrelationId(t *testing.T) { + correlation_id.DefaultIdGenerator = correlation_id.NewRandomIdGenerator( + rand.New(correlation_id.NewLockedSource(rand.NewSource(1))), ) + var handlerReq *http.Request req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) responseWriter := &httptest.ResponseRecorder{} @@ -27,52 +26,27 @@ func TestRequestId(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // not equal because req.WithContext create another request object assert.NotEqual(t, req, r) - assert.Equal(t, "p1LGIehp1s", r.Header.Get(request_id.HeaderName)) + assert.Equal(t, "p1LGIehp1s", r.Header.Get(correlation_id.HeaderName)) handlerReq = r }) - middleware.RequestId(request_id.NewConfig())(handler).ServeHTTP(responseWriter, req) - respHeaderValue := responseWriter.Header().Get(request_id.HeaderName) - reqContextValue := handlerReq.Context().Value(request_id.HeaderName).(string) - assert.Equal(t, "p1LGIehp1s", req.Header.Get(request_id.HeaderName)) + middleware.CorrelationId(correlation_id.NewConfig())(handler).ServeHTTP(responseWriter, req) + respHeaderValue := responseWriter.Header().Get(correlation_id.HeaderName) + reqContextValue := handlerReq.Context().Value(correlation_id.HeaderName).(string) + assert.Equal(t, "p1LGIehp1s", req.Header.Get(correlation_id.HeaderName)) assert.True(t, len(respHeaderValue) == 10) assert.True(t, len(reqContextValue) == 10) assert.True(t, respHeaderValue == reqContextValue) } -func TestRequestIdCustom(t *testing.T) { - var handlerReq *http.Request - req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) - responseWriter := &httptest.ResponseRecorder{} - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // not equal because req.WithContext create another request object - assert.NotEqual(t, req, r) - assert.Equal(t, "my_fake_request_id", r.Header.Get(request_id.HeaderName)) - handlerReq = r - }) - config := request_id.NewConfig() - config.IdGenerator = func(request *http.Request) string { - return "my_fake_request_id" - } - - middleware.RequestId(config)(handler).ServeHTTP(responseWriter, req) - headerValue := responseWriter.Header().Get(request_id.HeaderName) - reqContextValue := handlerReq.Context().Value(request_id.HeaderName).(string) - assert.NotEqual(t, "", req.Header.Get(request_id.HeaderName)) - assert.Equal(t, "my_fake_request_id", headerValue) - assert.Equal(t, "my_fake_request_id", reqContextValue) - assert.True(t, headerValue == reqContextValue) -} - // ===================================================================================================================== // ========================================= EXAMPLES ================================================================== // ===================================================================================================================== -func ExampleRequestId() { +func ExampleCorrelationId() { port := ":5001" - config := request_id.NewConfig() + config := correlation_id.NewConfig() // you can override default header name config.HeaderName = "my-personal-header-name" // you can override default id generator @@ -83,7 +57,7 @@ func ExampleRequestId() { // we recommend to use MiddlewareStack to simplify managing all wanted middlewares // caution middleware order matters stack := httpware.MiddlewareStack( - middleware.RequestId(config), + middleware.CorrelationId(config), ) srv := http.NewServeMux() diff --git a/request_id/config.go b/request_id/config.go deleted file mode 100644 index 7da4156..0000000 --- a/request_id/config.go +++ /dev/null @@ -1,19 +0,0 @@ -package request_id - -import ( - "net/http" -) - -const HeaderName = "Request-Id" - -type Config struct { - HeaderName string - IdGenerator func(*http.Request) string -} - -func NewConfig() *Config { - return &Config{ - HeaderName: HeaderName, - IdGenerator: DefaultIdGenerator.Generate, - } -} diff --git a/request_id/generator_test.go b/request_id/generator_test.go deleted file mode 100644 index 5db5f4f..0000000 --- a/request_id/generator_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package request_id_test - -import ( - "math/rand" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/gol4ng/httpware/request_id" -) - -func Test_Random(t *testing.T) { - assert.Equal(t, 10, len(request_id.DefaultIdGenerator.Generate(nil))) -} - -func Test_Random_NewSource(t *testing.T) { - r := rand.New(request_id.NewLockedSource(rand.NewSource(1))) - rg := request_id.NewRandomIdGenerator(r, 20) - for _, expectedId := range []string{"DHIMG9FpXzp1LGIehp1s", "zAHyfjXUlrGhblT7txWd"} { - assert.Equal(t, expectedId, rg.Generate(nil)) - } -} diff --git a/tripperware/request_id.go b/tripperware/correlation_id.go similarity index 64% rename from tripperware/request_id.go rename to tripperware/correlation_id.go index 8b57264..6fbdb62 100644 --- a/tripperware/request_id.go +++ b/tripperware/correlation_id.go @@ -5,18 +5,24 @@ import ( "net/http" "github.com/gol4ng/httpware" - "github.com/gol4ng/httpware/request_id" + "github.com/gol4ng/httpware/correlation_id" ) -// RequestId tripperware gets request id header if provided or generates a request id +// CorrelationId tripperware gets request id header if provided or generates a request id // It will add the request ID to request context -func RequestId(config *request_id.Config) httpware.Tripperware { +func CorrelationId(config *correlation_id.Config) httpware.Tripperware { return func(next http.RoundTripper) http.RoundTripper { return httpware.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { + if v, ok := req.Context().Value(config.HeaderName).(string); ok { + req.Header.Add(config.HeaderName, v) + return next.RoundTrip(req) + } + var id string if req.Header != nil { id = req.Header.Get(config.HeaderName) } + if id == "" { id = config.IdGenerator(req) // add requestId header to current request diff --git a/tripperware/request_id_test.go b/tripperware/correlation_id_test.go similarity index 57% rename from tripperware/request_id_test.go rename to tripperware/correlation_id_test.go index 6787224..eb13136 100644 --- a/tripperware/request_id_test.go +++ b/tripperware/correlation_id_test.go @@ -1,6 +1,7 @@ package tripperware_test import ( + "context" "fmt" "math/rand" "net/http" @@ -12,20 +13,19 @@ import ( "github.com/stretchr/testify/mock" "github.com/gol4ng/httpware" + "github.com/gol4ng/httpware/correlation_id" "github.com/gol4ng/httpware/mocks" - "github.com/gol4ng/httpware/request_id" "github.com/gol4ng/httpware/tripperware" ) func TestMain(m *testing.M){ - request_id.DefaultIdGenerator = request_id.NewRandomIdGenerator( - rand.New(request_id.NewLockedSource(rand.NewSource(1))), - 10, + correlation_id.DefaultIdGenerator = correlation_id.NewRandomIdGenerator( + rand.New(correlation_id.NewLockedSource(rand.NewSource(1))), ) os.Exit(m.Run()) } -func TestRequestId(t *testing.T) { +func TestCorrelationId(t *testing.T) { roundTripperMock := &mocks.RoundTripper{} req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) resp := &http.Response{ @@ -36,17 +36,42 @@ func TestRequestId(t *testing.T) { roundTripperMock.On("RoundTrip", mock.AnythingOfType("*http.Request")).Return(resp, nil).Run(func(args mock.Arguments) { innerReq := args.Get(0).(*http.Request) - assert.True(t, len(innerReq.Header.Get(request_id.HeaderName)) == 10) - assert.Equal(t, req.Header.Get(request_id.HeaderName), innerReq.Header.Get(request_id.HeaderName)) + assert.Len(t, innerReq.Header.Get(correlation_id.HeaderName), 10) + assert.Equal(t, req.Header.Get(correlation_id.HeaderName), innerReq.Header.Get(correlation_id.HeaderName)) }) - resp2, err := tripperware.RequestId(request_id.NewConfig())(roundTripperMock).RoundTrip(req) + resp2, err := tripperware.CorrelationId(correlation_id.NewConfig())(roundTripperMock).RoundTrip(req) assert.Nil(t, err) assert.Equal(t, resp, resp2) - assert.Equal(t, "p1LGIehp1s", req.Header.Get(request_id.HeaderName)) + assert.Equal(t, "p1LGIehp1s", req.Header.Get(correlation_id.HeaderName)) } -func TestRequestIdCustom(t *testing.T) { +func TestCorrelationId_AlreadyInContext(t *testing.T) { + config := correlation_id.NewConfig() + roundTripperMock := &mocks.RoundTripper{} + req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + req = req.WithContext(context.WithValue(req.Context(), config.HeaderName, "my_already_exist_correlation_id")) + + resp := &http.Response{ + Status: "OK", + StatusCode: http.StatusOK, + ContentLength: 30, + } + + roundTripperMock.On("RoundTrip", mock.AnythingOfType("*http.Request")).Return(resp, nil).Run(func(args mock.Arguments) { + innerReq := args.Get(0).(*http.Request) + assert.Equal(t, req, innerReq) + assert.Len(t, innerReq.Header.Get(config.HeaderName), 31) + assert.Equal(t, req.Header.Get(config.HeaderName), innerReq.Header.Get(config.HeaderName)) + }) + + resp2, err := tripperware.CorrelationId(config)(roundTripperMock).RoundTrip(req) + assert.Nil(t, err) + assert.Equal(t, resp, resp2) + assert.Equal(t, "my_already_exist_correlation_id", req.Header.Get(config.HeaderName)) +} + +func TestCorrelationIdCustom(t *testing.T) { roundTripperMock := &mocks.RoundTripper{} req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) resp := &http.Response{ @@ -57,15 +82,15 @@ func TestRequestIdCustom(t *testing.T) { roundTripperMock.On("RoundTrip", mock.AnythingOfType("*http.Request")).Return(resp, nil).Run(func(args mock.Arguments) { innerReq := args.Get(0).(*http.Request) - assert.Equal(t, "my_fake_request_id", innerReq.Header.Get(request_id.HeaderName)) + assert.Equal(t, "my_fake_correlation", innerReq.Header.Get(correlation_id.HeaderName)) }) - config := request_id.NewConfig() + config := correlation_id.NewConfig() config.IdGenerator = func(request *http.Request) string { - return "my_fake_request_id" + return "my_fake_correlation" } - resp2, err := tripperware.RequestId(config)(roundTripperMock).RoundTrip(req) + resp2, err := tripperware.CorrelationId(config)(roundTripperMock).RoundTrip(req) assert.Nil(t, err) assert.Equal(t, resp, resp2) } @@ -74,9 +99,9 @@ func TestRequestIdCustom(t *testing.T) { // ========================================= EXAMPLES ================================================================== // ===================================================================================================================== -func ExampleRequestId() { +func ExampleCorrelationId() { port := ":5005" - config := request_id.NewConfig() + config := correlation_id.NewConfig() // you can override default header name config.HeaderName = "my-personal-header-name" // you can override default id generator @@ -87,7 +112,7 @@ func ExampleRequestId() { // we recommend to use MiddlewareStack to simplify managing all wanted middleware // caution middleware order matter stack := httpware.TripperwareStack( - tripperware.RequestId(config), + tripperware.CorrelationId(config), ) // create http client using the tripperwareStack as RoundTripper