diff --git a/routers/gorillamux/router.go b/routers/gorillamux/router.go index bf551a751..811ba7d16 100644 --- a/routers/gorillamux/router.go +++ b/routers/gorillamux/router.go @@ -34,6 +34,13 @@ type routeMux struct { varsUpdater varsf } +type srv struct { + schemes []string + host, base string + server *openapi3.Server + varsUpdater varsf +} + var singleVariableMatcher = regexp.MustCompile(`^\{([^{}]+)\}$`) // TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request)) @@ -42,78 +49,22 @@ var singleVariableMatcher = regexp.MustCompile(`^\{([^{}]+)\}$`) // Assumes spec is .Validate()d // Note that a variable for the port number MUST have a default value and only this value will match as the port (see issue #367). func NewRouter(doc *openapi3.T) (routers.Router, error) { - type srv struct { - schemes []string - host, base string - server *openapi3.Server - varsUpdater varsf + servers, err := makeServers(doc.Servers) + if err != nil { + return nil, err } - servers := make([]srv, 0, len(doc.Servers)) - for _, server := range doc.Servers { - serverURL := server.URL - if submatch := singleVariableMatcher.FindStringSubmatch(serverURL); submatch != nil { - sVar := submatch[1] - sVal := server.Variables[sVar].Default - serverURL = strings.ReplaceAll(serverURL, "{"+sVar+"}", sVal) - var varsUpdater varsf - if lhs := strings.TrimSuffix(serverURL, server.Variables[sVar].Default); lhs != "" { - varsUpdater = func(vars map[string]string) { vars[sVar] = lhs } - } - servers = append(servers, srv{ - base: server.Variables[sVar].Default, - server: server, - varsUpdater: varsUpdater, - }) - continue - } - - var schemes []string - if strings.Contains(serverURL, "://") { - scheme0 := strings.Split(serverURL, "://")[0] - schemes = permutePart(scheme0, server) - serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1) - } - // If a variable represents the port "http://domain.tld:{port}/bla" - // then url.Parse() cannot parse "http://domain.tld:`bEncode({port})`/bla" - // and mux is not able to set the {port} variable - // So we just use the default value for this variable. - // See https://github.com/getkin/kin-openapi/issues/367 - var varsUpdater varsf - if lhs := strings.Index(serverURL, ":{"); lhs > 0 { - rest := serverURL[lhs+len(":{"):] - rhs := strings.Index(rest, "}") - portVariable := rest[:rhs] - portValue := server.Variables[portVariable].Default - serverURL = strings.ReplaceAll(serverURL, "{"+portVariable+"}", portValue) - varsUpdater = func(vars map[string]string) { - vars[portVariable] = portValue - } - } - - u, err := url.Parse(bEncode(serverURL)) - if err != nil { - return nil, err - } - path := bDecode(u.EscapedPath()) - if len(path) > 0 && path[len(path)-1] == '/' { - path = path[:len(path)-1] - } - servers = append(servers, srv{ - host: bDecode(u.Host), //u.Hostname()? - base: path, - schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624 - server: server, - varsUpdater: varsUpdater, - }) - } - if len(servers) == 0 { - servers = append(servers, srv{}) - } muxRouter := mux.NewRouter().UseEncodedPath() r := &Router{} for _, path := range orderedPaths(doc.Paths) { + servers := servers + pathItem := doc.Paths[path] + if len(pathItem.Servers) > 0 { + if servers, err = makeServers(pathItem.Servers); err != nil { + return nil, err + } + } operations := pathItem.Operations() methods := make([]string, 0, len(operations)) @@ -177,6 +128,73 @@ func (r *Router) FindRoute(req *http.Request) (*routers.Route, map[string]string return nil, nil, routers.ErrPathNotFound } +func makeServers(in openapi3.Servers) ([]srv, error) { + servers := make([]srv, 0, len(in)) + for _, server := range in { + serverURL := server.URL + if submatch := singleVariableMatcher.FindStringSubmatch(serverURL); submatch != nil { + sVar := submatch[1] + sVal := server.Variables[sVar].Default + serverURL = strings.ReplaceAll(serverURL, "{"+sVar+"}", sVal) + var varsUpdater varsf + if lhs := strings.TrimSuffix(serverURL, server.Variables[sVar].Default); lhs != "" { + varsUpdater = func(vars map[string]string) { vars[sVar] = lhs } + } + servers = append(servers, srv{ + base: server.Variables[sVar].Default, + server: server, + varsUpdater: varsUpdater, + }) + continue + } + + var schemes []string + if strings.Contains(serverURL, "://") { + scheme0 := strings.Split(serverURL, "://")[0] + schemes = permutePart(scheme0, server) + serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1) + } + + // If a variable represents the port "http://domain.tld:{port}/bla" + // then url.Parse() cannot parse "http://domain.tld:`bEncode({port})`/bla" + // and mux is not able to set the {port} variable + // So we just use the default value for this variable. + // See https://github.com/getkin/kin-openapi/issues/367 + var varsUpdater varsf + if lhs := strings.Index(serverURL, ":{"); lhs > 0 { + rest := serverURL[lhs+len(":{"):] + rhs := strings.Index(rest, "}") + portVariable := rest[:rhs] + portValue := server.Variables[portVariable].Default + serverURL = strings.ReplaceAll(serverURL, "{"+portVariable+"}", portValue) + varsUpdater = func(vars map[string]string) { + vars[portVariable] = portValue + } + } + + u, err := url.Parse(bEncode(serverURL)) + if err != nil { + return nil, err + } + path := bDecode(u.EscapedPath()) + if len(path) > 0 && path[len(path)-1] == '/' { + path = path[:len(path)-1] + } + servers = append(servers, srv{ + host: bDecode(u.Host), //u.Hostname()? + base: path, + schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624 + server: server, + varsUpdater: varsUpdater, + }) + } + if len(servers) == 0 { + servers = append(servers, srv{}) + } + + return servers, nil +} + func orderedPaths(paths map[string]*openapi3.PathItem) []string { // https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#pathsObject // When matching URLs, concrete (non-templated) paths would be matched diff --git a/routers/gorillamux/router_test.go b/routers/gorillamux/router_test.go index f8800baed..104056e18 100644 --- a/routers/gorillamux/router_test.go +++ b/routers/gorillamux/router_test.go @@ -254,6 +254,47 @@ func TestServerPath(t *testing.T) { require.NoError(t, err) } +func TestServerOverrideAtPathLevel(t *testing.T) { + helloGET := &openapi3.Operation{Responses: openapi3.NewResponses()} + doc := &openapi3.T{ + OpenAPI: "3.0.0", + Info: &openapi3.Info{ + Title: "rel", + Version: "1", + }, + Servers: openapi3.Servers{ + &openapi3.Server{ + URL: "https://example.com", + }, + }, + Paths: openapi3.Paths{ + "/hello": &openapi3.PathItem{ + Servers: openapi3.Servers{ + &openapi3.Server{ + URL: "https://another.com", + }, + }, + Get: helloGET, + }, + }, + } + err := doc.Validate(context.Background()) + require.NoError(t, err) + router, err := NewRouter(doc) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "https://another.com/hello", nil) + require.NoError(t, err) + route, _, err := router.FindRoute(req) + require.Equal(t, "/hello", route.Path) + + req, err = http.NewRequest(http.MethodGet, "https://example.com/hello", nil) + require.NoError(t, err) + route, _, err = router.FindRoute(req) + require.Nil(t, route) + require.Error(t, err) +} + func TestRelativeURL(t *testing.T) { helloGET := &openapi3.Operation{Responses: openapi3.NewResponses()} doc := &openapi3.T{