From 142adad5c92ca9f90d809aa7e841516c8c2435e6 Mon Sep 17 00:00:00 2001 From: Pierre Fenoll Date: Tue, 31 May 2022 17:53:19 +0200 Subject: [PATCH] Handle port number variable of servers given to gorillamux.NewRouter (#524) --- routers/gorillamux/router.go | 88 ++++++++++++++++++++++++------- routers/gorillamux/router_test.go | 25 +++++++++ 2 files changed, 95 insertions(+), 18 deletions(-) diff --git a/routers/gorillamux/router.go b/routers/gorillamux/router.go index 67be47452..bf551a751 100644 --- a/routers/gorillamux/router.go +++ b/routers/gorillamux/router.go @@ -9,6 +9,7 @@ package gorillamux import ( "net/http" "net/url" + "regexp" "sort" "strings" @@ -22,32 +23,75 @@ var _ routers.Router = &Router{} // Router helps link http.Request.s and an OpenAPIv3 spec type Router struct { - muxes []*mux.Route + muxes []routeMux routes []*routers.Route } +type varsf func(vars map[string]string) + +type routeMux struct { + muxRoute *mux.Route + varsUpdater varsf +} + +var singleVariableMatcher = regexp.MustCompile(`^\{([^{}]+)\}$`) + +// TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request)) + // NewRouter creates a gorilla/mux router. // Assumes spec is .Validate()d -// TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request)) +// Note that a variable for the port number MUST have a default value and only this value will match as the port (see issue #367). func NewRouter(doc *openapi3.T) (routers.Router, error) { type srv struct { - schemes []string - host, base string - server *openapi3.Server + schemes []string + host, base string + server *openapi3.Server + varsUpdater varsf } servers := make([]srv, 0, len(doc.Servers)) for _, server := range doc.Servers { serverURL := server.URL + if submatch := singleVariableMatcher.FindStringSubmatch(serverURL); submatch != nil { + sVar := submatch[1] + sVal := server.Variables[sVar].Default + serverURL = strings.ReplaceAll(serverURL, "{"+sVar+"}", sVal) + var varsUpdater varsf + if lhs := strings.TrimSuffix(serverURL, server.Variables[sVar].Default); lhs != "" { + varsUpdater = func(vars map[string]string) { vars[sVar] = lhs } + } + servers = append(servers, srv{ + base: server.Variables[sVar].Default, + server: server, + varsUpdater: varsUpdater, + }) + continue + } + var schemes []string - var u *url.URL - var err error if strings.Contains(serverURL, "://") { scheme0 := strings.Split(serverURL, "://")[0] schemes = permutePart(scheme0, server) - u, err = url.Parse(bEncode(strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1))) - } else { - u, err = url.Parse(bEncode(serverURL)) + serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1) } + + // If a variable represents the port "http://domain.tld:{port}/bla" + // then url.Parse() cannot parse "http://domain.tld:`bEncode({port})`/bla" + // and mux is not able to set the {port} variable + // So we just use the default value for this variable. + // See https://github.com/getkin/kin-openapi/issues/367 + var varsUpdater varsf + if lhs := strings.Index(serverURL, ":{"); lhs > 0 { + rest := serverURL[lhs+len(":{"):] + rhs := strings.Index(rest, "}") + portVariable := rest[:rhs] + portValue := server.Variables[portVariable].Default + serverURL = strings.ReplaceAll(serverURL, "{"+portVariable+"}", portValue) + varsUpdater = func(vars map[string]string) { + vars[portVariable] = portValue + } + } + + u, err := url.Parse(bEncode(serverURL)) if err != nil { return nil, err } @@ -56,10 +100,11 @@ func NewRouter(doc *openapi3.T) (routers.Router, error) { path = path[:len(path)-1] } servers = append(servers, srv{ - host: bDecode(u.Host), //u.Hostname()? - base: path, - schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624 - server: server, + host: bDecode(u.Host), //u.Hostname()? + base: path, + schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624 + server: server, + varsUpdater: varsUpdater, }) } if len(servers) == 0 { @@ -88,7 +133,10 @@ func NewRouter(doc *openapi3.T) (routers.Router, error) { if err := muxRoute.GetError(); err != nil { return nil, err } - r.muxes = append(r.muxes, muxRoute) + r.muxes = append(r.muxes, routeMux{ + muxRoute: muxRoute, + varsUpdater: s.varsUpdater, + }) r.routes = append(r.routes, &routers.Route{ Spec: doc, Server: s.server, @@ -104,16 +152,20 @@ func NewRouter(doc *openapi3.T) (routers.Router, error) { // FindRoute extracts the route and parameters of an http.Request func (r *Router) FindRoute(req *http.Request) (*routers.Route, map[string]string, error) { - for i, muxRoute := range r.muxes { + for i, m := range r.muxes { var match mux.RouteMatch - if muxRoute.Match(req, &match) { + if m.muxRoute.Match(req, &match) { if err := match.MatchErr; err != nil { // What then? } + vars := match.Vars + if f := m.varsUpdater; f != nil { + f(vars) + } route := *r.routes[i] route.Method = req.Method route.Operation = route.Spec.Paths[route.Path].GetOperation(route.Method) - return &route, match.Vars, nil + return &route, vars, nil } switch match.MatchErr { case nil: diff --git a/routers/gorillamux/router_test.go b/routers/gorillamux/router_test.go index 31bf416ed..1898db4ac 100644 --- a/routers/gorillamux/router_test.go +++ b/routers/gorillamux/router_test.go @@ -73,6 +73,7 @@ func TestRouter(t *testing.T) { } expect := func(r routers.Router, method string, uri string, operation *openapi3.Operation, params map[string]string) { + t.Helper() req, err := http.NewRequest(method, uri, nil) require.NoError(t, err) route, pathParams, err := r.FindRoute(req) @@ -164,6 +165,9 @@ func TestRouter(t *testing.T) { "d1": {Default: "example", Enum: []string{"example"}}, "scheme": {Default: "https", Enum: []string{"https", "http"}}, }}, + {URL: "http://127.0.0.1:{port}/api/v1", Variables: map[string]*openapi3.ServerVariable{ + "port": {Default: "8000"}, + }}, } err = doc.Validate(context.Background()) require.NoError(t, err) @@ -180,6 +184,20 @@ func TestRouter(t *testing.T) { "d1": "domain1", // "scheme": "https", TODO: https://github.com/gorilla/mux/issues/624 }) + expect(r, http.MethodGet, "http://127.0.0.1:8000/api/v1/hello", helloGET, map[string]string{ + "port": "8000", + }) + + doc.Servers = []*openapi3.Server{ + {URL: "{server}", Variables: map[string]*openapi3.ServerVariable{ + "server": {Default: "/api/v1"}, + }}, + } + err = doc.Validate(context.Background()) + require.NoError(t, err) + r, err = NewRouter(doc) + require.NoError(t, err) + expect(r, http.MethodGet, "https://myserver/api/v1/hello", helloGET, nil) { uri := "https://www.example.com/api/v1/onlyGET" @@ -224,6 +242,11 @@ func TestServerPath(t *testing.T) { func TestRelativeURL(t *testing.T) { helloGET := &openapi3.Operation{Responses: openapi3.NewResponses()} doc := &openapi3.T{ + OpenAPI: "3.0.0", + Info: &openapi3.Info{ + Title: "rel", + Version: "1", + }, Servers: openapi3.Servers{ &openapi3.Server{ URL: "/api/v1", @@ -235,6 +258,8 @@ func TestRelativeURL(t *testing.T) { }, }, } + err := doc.Validate(context.Background()) + require.NoError(t, err) router, err := NewRouter(doc) require.NoError(t, err) req, err := http.NewRequest(http.MethodGet, "https://example.com/api/v1/hello", nil)