Skip to content

Commit

Permalink
Handle port number variable of servers given to gorillamux.NewRouter (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
fenollp committed May 31, 2022
1 parent 12540af commit 142adad
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 18 deletions.
88 changes: 70 additions & 18 deletions routers/gorillamux/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package gorillamux
import (
"net/http"
"net/url"
"regexp"
"sort"
"strings"

Expand All @@ -22,32 +23,75 @@ var _ routers.Router = &Router{}

// Router helps link http.Request.s and an OpenAPIv3 spec
type Router struct {
muxes []*mux.Route
muxes []routeMux
routes []*routers.Route
}

type varsf func(vars map[string]string)

type routeMux struct {
muxRoute *mux.Route
varsUpdater varsf
}

var singleVariableMatcher = regexp.MustCompile(`^\{([^{}]+)\}$`)

// TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request))

// NewRouter creates a gorilla/mux router.
// Assumes spec is .Validate()d
// TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request))
// Note that a variable for the port number MUST have a default value and only this value will match as the port (see issue #367).
func NewRouter(doc *openapi3.T) (routers.Router, error) {
type srv struct {
schemes []string
host, base string
server *openapi3.Server
schemes []string
host, base string
server *openapi3.Server
varsUpdater varsf
}
servers := make([]srv, 0, len(doc.Servers))
for _, server := range doc.Servers {
serverURL := server.URL
if submatch := singleVariableMatcher.FindStringSubmatch(serverURL); submatch != nil {
sVar := submatch[1]
sVal := server.Variables[sVar].Default
serverURL = strings.ReplaceAll(serverURL, "{"+sVar+"}", sVal)
var varsUpdater varsf
if lhs := strings.TrimSuffix(serverURL, server.Variables[sVar].Default); lhs != "" {
varsUpdater = func(vars map[string]string) { vars[sVar] = lhs }
}
servers = append(servers, srv{
base: server.Variables[sVar].Default,
server: server,
varsUpdater: varsUpdater,
})
continue
}

var schemes []string
var u *url.URL
var err error
if strings.Contains(serverURL, "://") {
scheme0 := strings.Split(serverURL, "://")[0]
schemes = permutePart(scheme0, server)
u, err = url.Parse(bEncode(strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)))
} else {
u, err = url.Parse(bEncode(serverURL))
serverURL = strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)
}

// If a variable represents the port "http://domain.tld:{port}/bla"
// then url.Parse() cannot parse "http://domain.tld:`bEncode({port})`/bla"
// and mux is not able to set the {port} variable
// So we just use the default value for this variable.
// See https://github.com/getkin/kin-openapi/issues/367
var varsUpdater varsf
if lhs := strings.Index(serverURL, ":{"); lhs > 0 {
rest := serverURL[lhs+len(":{"):]
rhs := strings.Index(rest, "}")
portVariable := rest[:rhs]
portValue := server.Variables[portVariable].Default
serverURL = strings.ReplaceAll(serverURL, "{"+portVariable+"}", portValue)
varsUpdater = func(vars map[string]string) {
vars[portVariable] = portValue
}
}

u, err := url.Parse(bEncode(serverURL))
if err != nil {
return nil, err
}
Expand All @@ -56,10 +100,11 @@ func NewRouter(doc *openapi3.T) (routers.Router, error) {
path = path[:len(path)-1]
}
servers = append(servers, srv{
host: bDecode(u.Host), //u.Hostname()?
base: path,
schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
server: server,
host: bDecode(u.Host), //u.Hostname()?
base: path,
schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
server: server,
varsUpdater: varsUpdater,
})
}
if len(servers) == 0 {
Expand Down Expand Up @@ -88,7 +133,10 @@ func NewRouter(doc *openapi3.T) (routers.Router, error) {
if err := muxRoute.GetError(); err != nil {
return nil, err
}
r.muxes = append(r.muxes, muxRoute)
r.muxes = append(r.muxes, routeMux{
muxRoute: muxRoute,
varsUpdater: s.varsUpdater,
})
r.routes = append(r.routes, &routers.Route{
Spec: doc,
Server: s.server,
Expand All @@ -104,16 +152,20 @@ func NewRouter(doc *openapi3.T) (routers.Router, error) {

// FindRoute extracts the route and parameters of an http.Request
func (r *Router) FindRoute(req *http.Request) (*routers.Route, map[string]string, error) {
for i, muxRoute := range r.muxes {
for i, m := range r.muxes {
var match mux.RouteMatch
if muxRoute.Match(req, &match) {
if m.muxRoute.Match(req, &match) {
if err := match.MatchErr; err != nil {
// What then?
}
vars := match.Vars
if f := m.varsUpdater; f != nil {
f(vars)
}
route := *r.routes[i]
route.Method = req.Method
route.Operation = route.Spec.Paths[route.Path].GetOperation(route.Method)
return &route, match.Vars, nil
return &route, vars, nil
}
switch match.MatchErr {
case nil:
Expand Down
25 changes: 25 additions & 0 deletions routers/gorillamux/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func TestRouter(t *testing.T) {
}

expect := func(r routers.Router, method string, uri string, operation *openapi3.Operation, params map[string]string) {
t.Helper()
req, err := http.NewRequest(method, uri, nil)
require.NoError(t, err)
route, pathParams, err := r.FindRoute(req)
Expand Down Expand Up @@ -164,6 +165,9 @@ func TestRouter(t *testing.T) {
"d1": {Default: "example", Enum: []string{"example"}},
"scheme": {Default: "https", Enum: []string{"https", "http"}},
}},
{URL: "http://127.0.0.1:{port}/api/v1", Variables: map[string]*openapi3.ServerVariable{
"port": {Default: "8000"},
}},
}
err = doc.Validate(context.Background())
require.NoError(t, err)
Expand All @@ -180,6 +184,20 @@ func TestRouter(t *testing.T) {
"d1": "domain1",
// "scheme": "https", TODO: https://github.com/gorilla/mux/issues/624
})
expect(r, http.MethodGet, "http://127.0.0.1:8000/api/v1/hello", helloGET, map[string]string{
"port": "8000",
})

doc.Servers = []*openapi3.Server{
{URL: "{server}", Variables: map[string]*openapi3.ServerVariable{
"server": {Default: "/api/v1"},
}},
}
err = doc.Validate(context.Background())
require.NoError(t, err)
r, err = NewRouter(doc)
require.NoError(t, err)
expect(r, http.MethodGet, "https://myserver/api/v1/hello", helloGET, nil)

{
uri := "https://www.example.com/api/v1/onlyGET"
Expand Down Expand Up @@ -224,6 +242,11 @@ func TestServerPath(t *testing.T) {
func TestRelativeURL(t *testing.T) {
helloGET := &openapi3.Operation{Responses: openapi3.NewResponses()}
doc := &openapi3.T{
OpenAPI: "3.0.0",
Info: &openapi3.Info{
Title: "rel",
Version: "1",
},
Servers: openapi3.Servers{
&openapi3.Server{
URL: "/api/v1",
Expand All @@ -235,6 +258,8 @@ func TestRelativeURL(t *testing.T) {
},
},
}
err := doc.Validate(context.Background())
require.NoError(t, err)
router, err := NewRouter(doc)
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, "https://example.com/api/v1/hello", nil)
Expand Down

0 comments on commit 142adad

Please sign in to comment.