diff --git a/middleware/find_pattern.go b/middleware/find_pattern.go new file mode 100644 index 00000000..48a0b6b9 --- /dev/null +++ b/middleware/find_pattern.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "net/http" + + "github.com/go-chi/chi/v5" +) + +// Find the route pattern for the request path. +// +// This middleware does not need to be the last middleware to resolve the +// route pattern. The pattern is fully resolved before the request has been +// handled. +func FindPattern(routes chi.Routes, callback func(pattern string)) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + // Find mutates the context so always make a new one + rctx := chi.NewRouteContext() + path := r.URL.Path + op := r.Method + pattern := routes.Find(rctx, op, path) + callback(pattern) + + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } + +} diff --git a/middleware/find_pattern_test.go b/middleware/find_pattern_test.go new file mode 100644 index 00000000..e7b5e8d3 --- /dev/null +++ b/middleware/find_pattern_test.go @@ -0,0 +1,67 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestFindPattern(t *testing.T) { + t.Parallel() + + var tests = []struct { + pattern string + path string + }{ + { + "/", + "/", + }, + { + "/hi", + "/hi", + }, + { + "/{id}", + "/123", + }, + { + "/{id}/hello", + "/123/hello", + }, + { + "/users/*", + "/users/123", + }, + { + "/users/*", + "/users/123/hello", + }, + } + + for _, tt := range tests { + var tt = tt + t.Run(tt.pattern, func(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + + r := chi.NewRouter() + r.Use(FindPattern(r, func(pattern string) { + if pattern != tt.pattern { + t.Errorf("actual pattern \"%s\" does not equal expected pattern \"%s\"", pattern, tt.pattern) + } + })) + + r.Get(tt.pattern, func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("")) + }) + + req := httptest.NewRequest("GET", tt.path, nil) + r.ServeHTTP(recorder, req) + recorder.Result() + }) + } +}