diff --git a/router.go b/router.go index 4893950d..0faf5dee 100644 --- a/router.go +++ b/router.go @@ -76,9 +76,7 @@ // thirdValue := ps[2].Value // the value of the 3rd parameter package httprouter -import ( - "net/http" -) +import "net/http" // Handle is a function that can be registered to a route to handle HTTP // requests. Like http.HandlerFunc, but has a third parameter for the values of @@ -112,6 +110,12 @@ func (ps Params) ByName(name string) string { type Router struct { trees map[string]*node + // If enabled, routing will always use the original request path, not the + // unescaped one. For example if a /users/:user handler is used and + // /users/foo%2fbar is requested, the handler will be called with user=foo%2fbar + // but if this option is disabled, /users/foo/bar will be looked up instead. + RawPathRouting bool + // Enables automatic redirection if the current route can't be matched but a // handler for the path with (without) the trailing slash exists. // For example if /foo/ is requested but a route only exists for /foo, the @@ -268,6 +272,22 @@ func (r *Router) recv(w http.ResponseWriter, req *http.Request) { } } +func (r *Router) requestPath(req *http.Request) string { + if !r.RawPathRouting { + return req.URL.Path + } + path := req.RequestURI + pathLen := len(path) + if pathLen <= 0 { + return path + } + rawQueryLen := len(req.URL.RawQuery) + if rawQueryLen == 0 && path[pathLen-1] != '?' { + return path + } + return path[:pathLen-rawQueryLen-1] +} + // Lookup allows the manual lookup of a method + path combo. // This is e.g. useful to build a framework around this router. // If the path was found, it returns the handle function and the path parameter @@ -287,7 +307,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if root := r.trees[req.Method]; root != nil { - path := req.URL.Path + path := r.requestPath(req) if handle, ps, tsr := root.getValue(path); handle != nil { handle(w, req, ps) @@ -302,11 +322,11 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { if tsr && r.RedirectTrailingSlash { if len(path) > 1 && path[len(path)-1] == '/' { - req.URL.Path = path[:len(path)-1] + path = path[:len(path)-1] } else { - req.URL.Path = path + "/" + path = path + "/" } - http.Redirect(w, req, req.URL.String(), code) + http.Redirect(w, req, path, code) return } @@ -317,8 +337,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.RedirectTrailingSlash, ) if found { - req.URL.Path = string(fixedPath) - http.Redirect(w, req, req.URL.String(), code) + path = string(fixedPath) + http.Redirect(w, req, path, code) return } } @@ -333,7 +353,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { continue } - handle, _, _ := r.trees[method].getValue(req.URL.Path) + path := r.requestPath(req) + handle, _, _ := r.trees[method].getValue(path) if handle != nil { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), diff --git a/router_test.go b/router_test.go index 7b07dfba..5ed713cf 100644 --- a/router_test.go +++ b/router_test.go @@ -5,8 +5,11 @@ package httprouter import ( + "bufio" + "bytes" "errors" "fmt" + "net/http" "net/http/httptest" "reflect" @@ -29,6 +32,21 @@ func (m *mockResponseWriter) WriteString(s string) (n int, err error) { func (m *mockResponseWriter) WriteHeader(int) {} +func newRequest(t *testing.T, method, requestURI string) *http.Request { + buf := bytes.NewBuffer(nil) + reader := bufio.NewReader(buf) + buf.WriteString(fmt.Sprintf("%s %s HTTP/1.0\n", method, requestURI)) + buf.WriteString("\n") + req, err := http.ReadRequest(reader) + if err != nil { + req, err = http.NewRequest(method, requestURI, nil) + if err != nil { + t.Fatal(err) + } + } + return req +} + func TestParams(t *testing.T) { ps := Params{ Param{"param1", "value1"}, @@ -59,7 +77,7 @@ func TestRouter(t *testing.T) { w := new(mockResponseWriter) - req, _ := http.NewRequest("GET", "/user/gopher", nil) + req := newRequest(t, "GET", "/user/gopher") router.ServeHTTP(w, req) if !routed { @@ -67,6 +85,104 @@ func TestRouter(t *testing.T) { } } +func TestRawPathRouting(t *testing.T) { + router := New() + router.RawPathRouting = true + + routed := false + router.Handle("GET", "/path/:id", func(w http.ResponseWriter, r *http.Request, ps Params) { + routed = true + want := Params{Param{"id", "go%2fpher"}} + if !reflect.DeepEqual(ps, want) { + t.Fatalf("wrong wildcard values: want %v, got %v", want, ps) + } + }) + + w := new(mockResponseWriter) + + req := newRequest(t, "GET", "/path/go%2fpher") + router.ServeHTTP(w, req) + if !routed { + t.Fatal("routing failed") + } +} + +func TestRawPathRoutingMixed(t *testing.T) { + router := New() + router.RawPathRouting = true + + routed := false + router.Handle("GET", "/u/:u/pher/p/:p", func(w http.ResponseWriter, r *http.Request, ps Params) { + routed = true + want := Params{Param{"u", "go%2fpher"}, Param{"p", "pher%2fgo"}} + if !reflect.DeepEqual(ps, want) { + t.Fatalf("wrong wildcard values: want %v, got %v", want, ps) + } + }) + + w := new(mockResponseWriter) + + req := newRequest(t, "GET", "/u/go%2fpher/pher/p/pher%2fgo") + router.ServeHTTP(w, req) + if !routed { + t.Fatal("routing failed") + } +} + +func TestRawPathRoutingCleanPath(t *testing.T) { + router := New() + router.RawPathRouting = true + + routed := false + router.Handle("GET", "/u/:u/pher/p/:p", func(w http.ResponseWriter, r *http.Request, ps Params) { + routed = true + want := Params{Param{"u", "."}, Param{"p", ".."}} + if !reflect.DeepEqual(ps, want) { + t.Fatalf("wrong wildcard values: want %v, got %v", want, ps) + } + }) + + w := new(mockResponseWriter) + + req := newRequest(t, "GET", "/u/./pher/p/..") + router.ServeHTTP(w, req) + if !routed { + t.Fatal("routing failed") + } +} + +func TestRawPathRoutingNotFound(t *testing.T) { + handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {} + + router := New() + router.RawPathRouting = true + router.GET("/path/:id", handlerFunc) + router.GET("/dir/:id/", handlerFunc) + + testRoutes := []struct { + route string + code int + header string + }{ + {"/path/go%2fpher/", 301, "map[Location:[/path/go%2fpher]]"}, // TSR -/ + {"/dir/go%2fpher", 301, "map[Location:[/dir/go%2fpher/]]"}, // TSR +/ + {"/PATH/go%2fpher", 301, "map[Location:[/path/go%2fpher]]"}, // Fixed Case + {"/DIR/go%2fpher/", 301, "map[Location:[/dir/go%2fpher/]]"}, // Fixed Case + {"/PATH/go%2fpher/", 301, "map[Location:[/path/go%2fpher]]"}, // Fixed Case -/ + {"/DIR/go%2fpher", 301, "map[Location:[/dir/go%2fpher/]]"}, // Fixed Case +/ + {"/../path/go%2fpher", 301, "map[Location:[/path/go%2fpher]]"}, // CleanPath + {"/nope", 404, ""}, // NotFound + } + for _, tr := range testRoutes { + r := newRequest(t, "GET", tr.route) + w := httptest.NewRecorder() + router.ServeHTTP(w, r) + if !(w.Code == tr.code && (w.Code == 404 || fmt.Sprint(w.Header()) == tr.header)) { + t.Errorf("NotFound handling route %s failed: Code=%d, Header=%v", tr.route, w.Code, w.Header()) + } + } +} + type handlerStruct struct { handeled *bool } @@ -106,49 +222,49 @@ func TestRouterAPI(t *testing.T) { w := new(mockResponseWriter) - r, _ := http.NewRequest("GET", "/GET", nil) + r := newRequest(t, "GET", "/GET") router.ServeHTTP(w, r) if !get { t.Error("routing GET failed") } - r, _ = http.NewRequest("HEAD", "/GET", nil) + r = newRequest(t, "HEAD", "/GET") router.ServeHTTP(w, r) if !head { t.Error("routing HEAD failed") } - r, _ = http.NewRequest("POST", "/POST", nil) + r = newRequest(t, "POST", "/POST") router.ServeHTTP(w, r) if !post { t.Error("routing POST failed") } - r, _ = http.NewRequest("PUT", "/PUT", nil) + r = newRequest(t, "PUT", "/PUT") router.ServeHTTP(w, r) if !put { t.Error("routing PUT failed") } - r, _ = http.NewRequest("PATCH", "/PATCH", nil) + r = newRequest(t, "PATCH", "/PATCH") router.ServeHTTP(w, r) if !patch { t.Error("routing PATCH failed") } - r, _ = http.NewRequest("DELETE", "/DELETE", nil) + r = newRequest(t, "DELETE", "/DELETE") router.ServeHTTP(w, r) if !delete { t.Error("routing DELETE failed") } - r, _ = http.NewRequest("GET", "/Handler", nil) + r = newRequest(t, "GET", "/Handler") router.ServeHTTP(w, r) if !handler { t.Error("routing Handler failed") } - r, _ = http.NewRequest("GET", "/HandlerFunc", nil) + r = newRequest(t, "GET", "/HandlerFunc") router.ServeHTTP(w, r) if !handlerFunc { t.Error("routing HandlerFunc failed") @@ -172,7 +288,7 @@ func TestRouterNotAllowed(t *testing.T) { router.POST("/path", handlerFunc) // Test not allowed - r, _ := http.NewRequest("GET", "/path", nil) + r := newRequest(t, "GET", "/path") w := httptest.NewRecorder() router.ServeHTTP(w, r) if !(w.Code == http.StatusMethodNotAllowed) { @@ -204,7 +320,7 @@ func TestRouterNotFound(t *testing.T) { {"/nope", 404, ""}, // NotFound } for _, tr := range testRoutes { - r, _ := http.NewRequest("GET", tr.route, nil) + r := newRequest(t, "GET", tr.route) w := httptest.NewRecorder() router.ServeHTTP(w, r) if !(w.Code == tr.code && (w.Code == 404 || fmt.Sprint(w.Header()) == tr.header)) { @@ -218,7 +334,7 @@ func TestRouterNotFound(t *testing.T) { rw.WriteHeader(404) notFound = true } - r, _ := http.NewRequest("GET", "/nope", nil) + r := newRequest(t, "GET", "/nope") w := httptest.NewRecorder() router.ServeHTTP(w, r) if !(w.Code == 404 && notFound == true) { @@ -227,7 +343,7 @@ func TestRouterNotFound(t *testing.T) { // Test other method than GET (want 307 instead of 301) router.PATCH("/path", handlerFunc) - r, _ = http.NewRequest("PATCH", "/path/", nil) + r = newRequest(t, "PATCH", "/path/") w = httptest.NewRecorder() router.ServeHTTP(w, r) if !(w.Code == 307 && fmt.Sprint(w.Header()) == "map[Location:[/path]]") { @@ -237,7 +353,7 @@ func TestRouterNotFound(t *testing.T) { // Test special case where no node for the prefix "/" exists router = New() router.GET("/a", handlerFunc) - r, _ = http.NewRequest("GET", "/", nil) + r = newRequest(t, "GET", "/") w = httptest.NewRecorder() router.ServeHTTP(w, r) if !(w.Code == 404) { @@ -258,7 +374,7 @@ func TestRouterPanicHandler(t *testing.T) { }) w := new(mockResponseWriter) - req, _ := http.NewRequest("PUT", "/user/gopher", nil) + req := newRequest(t, "PUT", "/user/gopher") defer func() { if rcv := recover(); rcv != nil { @@ -347,7 +463,7 @@ func TestRouterServeFiles(t *testing.T) { router.ServeFiles("/*filepath", mfs) w := new(mockResponseWriter) - r, _ := http.NewRequest("GET", "/favicon.ico", nil) + r := newRequest(t, "GET", "/favicon.ico") router.ServeHTTP(w, r) if !mfs.opened { t.Error("serving file failed")