diff --git a/proxy/http_integration_test.go b/proxy/http_integration_test.go index ff3627398..19973386b 100644 --- a/proxy/http_integration_test.go +++ b/proxy/http_integration_test.go @@ -177,6 +177,49 @@ func TestProxyHost(t *testing.T) { }) } +func TestRedirect(t *testing.T) { + routes := "route add mock / http://a.com/$path opts \"redirect=301\"\n" + routes += "route add mock /foo http://a.com/abc opts \"redirect=301\"\n" + routes += "route add mock /bar http://b.com/$path opts \"redirect=302 strip=/bar\"\n" + tbl, _ := route.NewTable(routes) + + proxy := httptest.NewServer(&HTTPProxy{ + Transport: http.DefaultTransport, + Lookup: func(r *http.Request) *route.Target { + return tbl.Lookup(r, "", route.Picker["rr"], route.Matcher["prefix"]) + }, + }) + defer proxy.Close() + + tests := []struct { + req string + wantCode int + wantLoc string + }{ + {req: "/", wantCode: 301, wantLoc: "http://a.com/"}, + {req: "/aaa/bbb", wantCode: 301, wantLoc: "http://a.com/aaa/bbb"}, + {req: "/foo", wantCode: 301, wantLoc: "http://a.com/abc"}, + {req: "/bar", wantCode: 302, wantLoc: "http://b.com"}, + {req: "/bar/aaa", wantCode: 302, wantLoc: "http://b.com/aaa"}, + } + + http.DefaultClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + // do not follow redirects + return http.ErrUseLastResponse + } + + for _, tt := range tests { + resp, _ := mustGet(proxy.URL + tt.req) + if resp.StatusCode != tt.wantCode { + t.Errorf("got status code %d, want %d", resp.StatusCode, tt.wantCode) + } + gotLoc, _ := resp.Location() + if gotLoc.String() != tt.wantLoc { + t.Errorf("got location %s, want %s", gotLoc, tt.wantLoc) + } + } +} + func TestProxyLogOutput(t *testing.T) { // build a format string from all log fields and one header field fields := []string{"header.X-Foo:$header.X-Foo"} diff --git a/proxy/http_proxy.go b/proxy/http_proxy.go index eb14b0c52..2c17237d0 100644 --- a/proxy/http_proxy.go +++ b/proxy/http_proxy.go @@ -60,6 +60,14 @@ func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { panic("no lookup function") } + if p.Config.RequestID != "" { + id := p.UUID + if id == nil { + id = uuid.NewUUID + } + r.Header.Set(p.Config.RequestID, id()) + } + t := p.Lookup(r) if t == nil { w.WriteHeader(p.Config.NoRouteStatus) @@ -75,6 +83,16 @@ func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { RawQuery: r.URL.RawQuery, } + if t.RedirectCode != 0 { + redirectURL := t.GetRedirectURL(requestURL) + http.Redirect(w, r, redirectURL.String(), t.RedirectCode) + if t.Timer != nil { + t.Timer.Update(0) + } + metrics.DefaultRegistry.GetTimer(key(t.RedirectCode)).Update(0) + return + } + // build the real target url that is passed to the proxy targetURL := &url.URL{ Scheme: t.URL.Scheme, @@ -106,14 +124,6 @@ func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if p.Config.RequestID != "" { - id := p.UUID - if id == nil { - id = uuid.NewUUID - } - r.Header.Set(p.Config.RequestID, id()) - } - upgrade, accept := r.Header.Get("Upgrade"), r.Header.Get("Accept") tr := p.Transport diff --git a/registry/consul/parse_test.go b/registry/consul/parse_test.go index 38d3ae274..6c3054a02 100644 --- a/registry/consul/parse_test.go +++ b/registry/consul/parse_test.go @@ -47,6 +47,12 @@ func TestParseTag(t *testing.T) { route: "xx/Yy", ok: true, }, + { + tag: "p-www.bar.com:80/foo redirect=302,https://www.bar.com", + route: "www.bar.com:80/foo", + opts: "redirect=302,https://www.bar.com", + ok: true, + }, } for i, tt := range tests { diff --git a/registry/consul/service.go b/registry/consul/service.go index 7b6d787b3..98df8fceb 100644 --- a/registry/consul/service.go +++ b/registry/consul/service.go @@ -1,6 +1,7 @@ package consul import ( + "fmt" "log" "net" "runtime" @@ -125,6 +126,15 @@ func serviceConfig(client *api.Client, name string, passing map[string]bool, tag dst = "https://" + addr case strings.HasPrefix(o, "weight="): weight = o[len("weight="):] + case strings.HasPrefix(o, "redirect="): + redir := strings.Split(o[len("redirect="):], ",") + if len(redir) == 2 { + dst = redir[1] + ropts = append(ropts, fmt.Sprintf("redirect=%s", redir[0])) + } else { + log.Printf("[ERROR] Invalid syntax for redirect: %s. should be redirect=,", o) + continue + } default: ropts = append(ropts, o) } diff --git a/route/route.go b/route/route.go index 152c4d30e..f0659c923 100644 --- a/route/route.go +++ b/route/route.go @@ -6,6 +6,7 @@ import ( "net/url" "reflect" "sort" + "strconv" "strings" "github.com/fabiolb/fabio/metrics" @@ -68,6 +69,16 @@ func (r *Route) addTarget(service string, targetURL *url.URL, fixedWeight float6 t.StripPath = opts["strip"] t.TLSSkipVerify = opts["tlsskipverify"] == "true" t.Host = opts["host"] + + if opts["redirect"] != "" { + t.RedirectCode, err = strconv.Atoi(opts["redirect"]) + if err != nil { + log.Printf("[ERROR] redirect status code should be numeric in 3xx range. Got: %s", opts["redirect"]) + } else if t.RedirectCode < 300 || t.RedirectCode > 399 { + t.RedirectCode = 0 + log.Printf("[ERROR] redirect status code should be in 3xx range. Got: %s", opts["redirect"]) + } + } } r.Targets = append(r.Targets, t) diff --git a/route/table.go b/route/table.go index 054185eb5..6a23021d4 100644 --- a/route/table.go +++ b/route/table.go @@ -273,12 +273,12 @@ func (t Table) route(host, path string) *Route { // normalizeHost returns the hostname from the request // and removes the default port if present. -func normalizeHost(req *http.Request) string { - host := strings.ToLower(req.Host) - if req.TLS == nil && strings.HasSuffix(host, ":80") { +func normalizeHost(host string, tls bool) string { + host = strings.ToLower(host) + if !tls && strings.HasSuffix(host, ":80") { return host[:len(host)-len(":80")] } - if req.TLS != nil && strings.HasSuffix(host, ":443") { + if tls && strings.HasSuffix(host, ":443") { return host[:len(host)-len(":443")] } return host @@ -287,9 +287,10 @@ func normalizeHost(req *http.Request) string { // matchingHosts returns all keys (host name patterns) from the // routing table which match the normalized request hostname. func (t Table) matchingHosts(req *http.Request) (hosts []string) { - host := normalizeHost(req) + host := normalizeHost(req.Host, req.TLS != nil) for pattern := range t { - if glob.Glob(pattern, host) { + normpat := normalizeHost(pattern, req.TLS != nil) + if glob.Glob(normpat, host) { hosts = append(hosts, pattern) } } diff --git a/route/table_test.go b/route/table_test.go index c631cad37..eb0fdea71 100644 --- a/route/table_test.go +++ b/route/table_test.go @@ -477,7 +477,7 @@ func TestNormalizeHost(t *testing.T) { } for i, tt := range tests { - if got, want := normalizeHost(tt.req), tt.host; got != want { + if got, want := normalizeHost(tt.req.Host, tt.req.TLS != nil), tt.host; got != want { t.Errorf("%d: got %v want %v", i, got, want) } } @@ -495,6 +495,7 @@ func TestTableLookup(t *testing.T) { route add svc z.abc.com/foo/ http://foo.com:3100 route add svc *.abc.com/ http://foo.com:4000 route add svc *.abc.com/foo/ http://foo.com:5000 + route add svc xyz.com:80/ https://xyz.com ` tbl, err := NewTable(s) @@ -539,6 +540,9 @@ func TestTableLookup(t *testing.T) { // exact match has precedence over glob match {&http.Request{Host: "z.abc.com", URL: mustParse("/foo/")}, "http://foo.com:3100"}, + + // explicit port on route + {&http.Request{Host: "xyz.com", URL: mustParse("/")}, "https://xyz.com"}, } for i, tt := range tests { diff --git a/route/target.go b/route/target.go index 14b0a3a45..6f5a70034 100644 --- a/route/target.go +++ b/route/target.go @@ -2,6 +2,7 @@ package route import ( "net/url" + "strings" "github.com/fabiolb/fabio/metrics" ) @@ -33,6 +34,10 @@ type Target struct { // URL is the endpoint the service instance listens on URL *url.URL + // RedirectCode is the HTTP status code used for redirects. + // When set to a value > 0 the client is redirected to the target url. + RedirectCode int + // FixedWeight is the weight assigned to this target. // If the value is 0 the targets weight is dynamic. FixedWeight float64 @@ -46,3 +51,29 @@ type Target struct { // TimerName is the name of the timer in the metrics registry TimerName string } + +func (t *Target) GetRedirectURL(requestURL *url.URL) *url.URL { + redirectURL := &url.URL{ + Scheme: t.URL.Scheme, + Host: t.URL.Host, + Path: t.URL.Path, + RawQuery: t.URL.RawQuery, + } + if strings.HasSuffix(redirectURL.Host, "$path") { + redirectURL.Host = redirectURL.Host[:len(redirectURL.Host)-len("$path")] + redirectURL.Path = "$path" + } + if strings.Contains(redirectURL.Path, "/$path") { + redirectURL.Path = strings.Replace(redirectURL.Path, "/$path", "$path", 1) + } + if strings.Contains(redirectURL.Path, "$path") { + redirectURL.Path = strings.Replace(redirectURL.Path, "$path", requestURL.Path, 1) + if t.StripPath != "" && strings.HasPrefix(redirectURL.Path, t.StripPath) { + redirectURL.Path = redirectURL.Path[len(t.StripPath):] + } + if redirectURL.RawQuery == "" && requestURL.RawQuery != "" { + redirectURL.RawQuery = requestURL.RawQuery + } + } + return redirectURL +} diff --git a/route/target_test.go b/route/target_test.go new file mode 100644 index 000000000..7d9c7794b --- /dev/null +++ b/route/target_test.go @@ -0,0 +1,104 @@ +package route + +import ( + "net/url" + "testing" +) + +func TestTarget_GetRedirectURL(t *testing.T) { + type routeTest struct { + req string + want string + } + tests := []struct { + route string + tests []routeTest + }{ + { // simple absolute redirect + route: "route add svc / http://bar.com/", + tests: []routeTest{ + {req: "/", want: "http://bar.com/"}, + {req: "/abc", want: "http://bar.com/"}, + {req: "/a/b/c", want: "http://bar.com/"}, + {req: "/?aaa=1", want: "http://bar.com/"}, + }, + }, + { // absolute redirect to deep path with query + route: "route add svc / http://bar.com/a/b/c?foo=bar", + tests: []routeTest{ + {req: "/", want: "http://bar.com/a/b/c?foo=bar"}, + {req: "/abc", want: "http://bar.com/a/b/c?foo=bar"}, + {req: "/a/b/c", want: "http://bar.com/a/b/c?foo=bar"}, + {req: "/?aaa=1", want: "http://bar.com/a/b/c?foo=bar"}, + }, + }, + { // simple redirect to corresponding path + route: "route add svc / http://bar.com/$path", + tests: []routeTest{ + {req: "/", want: "http://bar.com/"}, + {req: "/abc", want: "http://bar.com/abc"}, + {req: "/a/b/c", want: "http://bar.com/a/b/c"}, + {req: "/?aaa=1", want: "http://bar.com/?aaa=1"}, + {req: "/abc/?aaa=1", want: "http://bar.com/abc/?aaa=1"}, + }, + }, + { // same as above but without / before $path + route: "route add svc / http://bar.com$path", + tests: []routeTest{ + {req: "/", want: "http://bar.com/"}, + {req: "/abc", want: "http://bar.com/abc"}, + {req: "/a/b/c", want: "http://bar.com/a/b/c"}, + {req: "/?aaa=1", want: "http://bar.com/?aaa=1"}, + {req: "/abc/?aaa=1", want: "http://bar.com/abc/?aaa=1"}, + }, + }, + { // arbitrary subdir on target with $path at end + route: "route add svc / http://bar.com/bbb/$path", + tests: []routeTest{ + {req: "/", want: "http://bar.com/bbb/"}, + {req: "/abc", want: "http://bar.com/bbb/abc"}, + {req: "/a/b/c", want: "http://bar.com/bbb/a/b/c"}, + {req: "/?aaa=1", want: "http://bar.com/bbb/?aaa=1"}, + {req: "/abc/?aaa=1", want: "http://bar.com/bbb/abc/?aaa=1"}, + }, + }, + { // same as above but without / before $path + route: "route add svc / http://bar.com/bbb$path", + tests: []routeTest{ + {req: "/", want: "http://bar.com/bbb/"}, + {req: "/abc", want: "http://bar.com/bbb/abc"}, + {req: "/a/b/c", want: "http://bar.com/bbb/a/b/c"}, + {req: "/?aaa=1", want: "http://bar.com/bbb/?aaa=1"}, + {req: "/abc/?aaa=1", want: "http://bar.com/bbb/abc/?aaa=1"}, + }, + }, + { // strip prefix + route: "route add svc /stripme http://bar.com/$path opts \"strip=/stripme\"", + tests: []routeTest{ + {req: "/stripme/", want: "http://bar.com/"}, + {req: "/stripme/abc", want: "http://bar.com/abc"}, + {req: "/stripme/a/b/c", want: "http://bar.com/a/b/c"}, + {req: "/stripme/?aaa=1", want: "http://bar.com/?aaa=1"}, + {req: "/stripme/abc/?aaa=1", want: "http://bar.com/abc/?aaa=1"}, + }, + }, + } + firstRoute := func(tbl Table) *Route { + for _, routes := range tbl { + return routes[0] + } + return nil + } + for _, tt := range tests { + tbl, _ := NewTable(tt.route) + route := firstRoute(tbl) + target := route.Targets[0] + for _, rt := range tt.tests { + reqURL, _ := url.Parse("http://foo.com" + rt.req) + got := target.GetRedirectURL(reqURL) + if got.String() != rt.want { + t.Errorf("Got %s, wanted %s", got, rt.want) + } + } + } +}