diff --git a/middleware/realip.go b/middleware/realip.go index d91f2d1c..55c95a89 100644 --- a/middleware/realip.go +++ b/middleware/realip.go @@ -9,11 +9,9 @@ import ( "strings" ) -var defaultHeaders = []string{ - "True-Client-IP", // Cloudflare Enterprise plan - "X-Real-IP", - "X-Forwarded-For", -} +var trueClientIP = http.CanonicalHeaderKey("True-Client-IP") +var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") +var xRealIP = http.CanonicalHeaderKey("X-Real-IP") // RealIP is a middleware that sets a http.Request's RemoteAddr to the results // of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers @@ -32,7 +30,7 @@ var defaultHeaders = []string{ // how you're using RemoteAddr, vulnerable to an attack of some sort). func RealIP(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - if rip := getRealIP(r, defaultHeaders); rip != "" { + if rip := realIP(r); rip != "" { r.RemoteAddr = rip } h.ServeHTTP(w, r) @@ -41,33 +39,22 @@ func RealIP(h http.Handler) http.Handler { return http.HandlerFunc(fn) } -// RealIPFromHeaders is a middleware that sets a http.Request's RemoteAddr to the results -// of parsing the custom headers. -// -// usage: -// r.Use(RealIPFromHeaders("CF-Connecting-IP")) -func RealIPFromHeaders(headers ...string) func(http.Handler) http.Handler { - f := func(h http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - if rip := getRealIP(r, headers); rip != "" { - r.RemoteAddr = rip - } - h.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) - } - return f -} +func realIP(r *http.Request) string { + var ip string -func getRealIP(r *http.Request, headers []string) string { - for _, header := range headers { - if ip := r.Header.Get(header); ip != "" { - ips := strings.Split(ip, ",") - if ips[0] == "" || net.ParseIP(ips[0]) == nil { - continue - } - return ips[0] + if tcip := r.Header.Get(trueClientIP); tcip != "" { + ip = tcip + } else if xrip := r.Header.Get(xRealIP); xrip != "" { + ip = xrip + } else if xff := r.Header.Get(xForwardedFor); xff != "" { + i := strings.Index(xff, ",") + if i == -1 { + i = len(xff) } + ip = xff[:i] + } + if ip == "" || net.ParseIP(ip) == nil { + return "" } - return "" + return ip } diff --git a/middleware/realip_test.go b/middleware/realip_test.go index 97370323..1ab5e95e 100644 --- a/middleware/realip_test.go +++ b/middleware/realip_test.go @@ -113,52 +113,3 @@ func TestInvalidIP(t *testing.T) { t.Fatal("Invalid IP used.") } } - -func TestCustomIPHeader(t *testing.T) { - var customHeaderKey = "X-CUSTOM-IP" - req, _ := http.NewRequest("GET", "/", nil) - req.Header.Add(customHeaderKey, "100.100.100.100") - w := httptest.NewRecorder() - - r := chi.NewRouter() - r.Use(RealIPFromHeaders(customHeaderKey)) - - realIP := "" - r.Get("/", func(w http.ResponseWriter, r *http.Request) { - realIP = r.RemoteAddr - w.Write([]byte("Hello World")) - }) - r.ServeHTTP(w, req) - - if w.Code != 200 { - t.Fatal("Response Code should be 200") - } - - if realIP != "100.100.100.100" { - t.Fatal("Test get real IP precedence error.") - } -} - -func TestCustomIPHeaderWithoutDefault(t *testing.T) { - req, _ := http.NewRequest("GET", "/", nil) - req.Header.Add("X-REAL-IP", "100.100.100.100") - w := httptest.NewRecorder() - - r := chi.NewRouter() - r.Use(RealIPFromHeaders("CF-Connecting-IP")) - - realIP := "" - r.Get("/", func(w http.ResponseWriter, r *http.Request) { - realIP = r.RemoteAddr - w.Write([]byte("Hello World")) - }) - r.ServeHTTP(w, req) - - if w.Code != 200 { - t.Fatal("Response Code should be 200") - } - - if realIP != "" { - t.Fatal("Invalid IP used.") - } -}