Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce verbosity in GenerateRoutes #3175

Merged
merged 2 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/server/apimigrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,18 +231,18 @@ func (m *apiMigration) RedirectHandler() gin.HandlerFunc {

type HTTPMethodBindFunc func(relativePath string, handlers ...gin.HandlerFunc) gin.IRoutes

func bindRoute(a *API, r *gin.RouterGroup, method, path string, handler gin.HandlerFunc) {
func bindRoute(a *API, r *gin.RouterGroup, routeID routeIdentifier, handler gin.HandlerFunc) {
// build up the handlers into a map of all the paths we need to bind into.
routes := map[string][]gin.HandlerFunc{}
// set the default path
routes[path] = []gin.HandlerFunc{handler}
routes[routeID.path] = []gin.HandlerFunc{handler}

// we're going to build this list in referse order, prepending middleware.
// we start with the current migration and prepend versions backwards,
// 0.1.3, then 0.1.2, then 0.1.1.
sort.Slice(a.migrations, sortVersionDescendingOrder(a.migrations))
for _, migration := range a.migrations {
if strings.ToUpper(migration.method) != method {
if strings.ToUpper(migration.method) != routeID.method {
continue
}

Expand Down Expand Up @@ -274,7 +274,7 @@ func bindRoute(a *API, r *gin.RouterGroup, method, path string, handler gin.Hand

// now bind all relevant paths with Gin
for path, handlers := range routes {
r.Handle(method, path, handlers...)
r.Handle(routeID.method, path, handlers...)
}
}

Expand Down
7 changes: 7 additions & 0 deletions internal/server/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ import (
"github.com/infrahq/infra/internal/server/models"
)

var pprofRoute = route[api.EmptyRequest, *api.EmptyResponse]{
handler: pprofHandler,
omitFromTelemetry: true,
omitFromDocs: true,
infraVersionHeaderOptional: true,
}

func pprofHandler(c *gin.Context, _ *api.EmptyRequest) (*api.EmptyResponse, error) {
if _, err := access.RequireInfraRole(c, models.InfraSupportAdminRole); err != nil {
return nil, access.HandleAuthErr(err, "debug", "run", models.InfraSupportAdminRole)
Expand Down
11 changes: 9 additions & 2 deletions internal/server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ func (a *API) CreateToken(c *gin.Context, r *api.EmptyRequest) (*api.CreateToken
return nil, fmt.Errorf("%w: no identity found in access key", internal.ErrUnauthorized)
}

type WellKnownJWKResponse struct {
Keys []jose.JSONWebKey `json:"keys"`
var wellKnownJWKsRoute = route[api.EmptyRequest, WellKnownJWKResponse]{
handler: wellKnownJWKsHandler,
omitFromDocs: true,
omitFromTelemetry: true,
infraVersionHeaderOptional: true,
}

func wellKnownJWKsHandler(c *gin.Context, _ *api.EmptyRequest) (WellKnownJWKResponse, error) {
Expand All @@ -63,6 +66,10 @@ func wellKnownJWKsHandler(c *gin.Context, _ *api.EmptyRequest) (WellKnownJWKResp
return WellKnownJWKResponse{Keys: keys}, nil
}

type WellKnownJWKResponse struct {
Keys []jose.JSONWebKey `json:"keys"`
}

func (a *API) ListAccessKeys(c *gin.Context, r *api.ListAccessKeysRequest) (*api.ListResponse[api.AccessKey], error) {
p := PaginationFromRequest(r.PaginationRequest)
accessKeys, err := access.ListAccessKeys(c, r.UserID, r.Name, r.ShowExpired, &p)
Expand Down
4 changes: 2 additions & 2 deletions internal/server/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ var funcPartialNameToTagNames = []struct {
// openAPIRouteDefinition converts the route into a format that can be used
// by API.register. This is necessary because currently methods can not have
// generic parameters.
func openAPIRouteDefinition[Req, Res any](route route[Req, Res]) (
func openAPIRouteDefinition[Req, Res any](routeID routeIdentifier, route route[Req, Res]) (
method string,
path string,
funcName string,
Expand All @@ -57,7 +57,7 @@ func openAPIRouteDefinition[Req, Res any](route route[Req, Res]) (
) {
//nolint:gocritic
reqT, resultT := reflect.TypeOf(*new(Req)), reflect.TypeOf(*new(Res))
return route.method, route.path, getFuncName(route.handler), reqT, resultT
return routeID.method, routeID.path, getFuncName(route.handler), reqT, resultT
}

// register adds the route to the API.OpenAPIDocument.
Expand Down
62 changes: 24 additions & 38 deletions internal/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/getkin/kin-openapi/openapi3"
"github.com/gin-gonic/gin"

"github.com/infrahq/infra/api"
"github.com/infrahq/infra/internal"
"github.com/infrahq/infra/internal/logging"
"github.com/infrahq/infra/internal/validate"
Expand Down Expand Up @@ -97,14 +96,7 @@ func (s *Server) GenerateRoutes() Routes {

put(a, authn, "/api/settings", a.UpdateSettings)

add(a, authn, route[api.EmptyRequest, *api.EmptyResponse]{
method: http.MethodGet,
path: "/api/debug/pprof/*profile",
handler: pprofHandler,
omitFromTelemetry: true,
omitFromDocs: true,
infraVersionHeaderOptional: true,
})
add(a, authn, http.MethodGet, "/api/debug/pprof/*profile", pprofRoute)

// no auth required, org not required
noAuthnNoOrg := &routeGroup{RouterGroup: apiGroup.Group("/"), noAuthentication: true, noOrgRequired: true}
Expand All @@ -124,15 +116,7 @@ func (s *Server) GenerateRoutes() Routes {
get(a, noAuthnWithOrg, "/api/providers", a.ListProviders)
get(a, noAuthnWithOrg, "/api/settings", a.GetSettings)

// no auth required, org required, undocumented in api spec
add(a, noAuthnWithOrg, route[api.EmptyRequest, WellKnownJWKResponse]{
method: http.MethodGet,
path: "/.well-known/jwks.json",
handler: wellKnownJWKsHandler,
omitFromDocs: true,
omitFromTelemetry: true,
infraVersionHeaderOptional: true,
})
add(a, noAuthnWithOrg, http.MethodGet, "/.well-known/jwks.json", wellKnownJWKsRoute)

a.deprecatedRoutes(noAuthnNoOrg)

Expand All @@ -148,8 +132,6 @@ func (s *Server) GenerateRoutes() Routes {
type HandlerFunc[Req, Res any] func(c *gin.Context, req *Req) (Res, error)

type route[Req, Res any] struct {
method string
path string
handler HandlerFunc[Req, Res]
omitFromDocs bool
omitFromTelemetry bool
Expand All @@ -158,6 +140,11 @@ type route[Req, Res any] struct {
noOrgRequired bool
}

type routeIdentifier struct {
method string
path string
}

// TODO: replace this when routes are defined as package-level vars instead of
// constructed from the get, post, put, del helper functions.
type routeGroup struct {
Expand All @@ -166,22 +153,25 @@ type routeGroup struct {
noOrgRequired bool
}

func add[Req, Res any](a *API, group *routeGroup, route route[Req, Res]) {
route.path = path.Join(group.BasePath(), route.path)
func add[Req, Res any](a *API, group *routeGroup, method, urlPath string, route route[Req, Res]) {
routeID := routeIdentifier{
method: method,
path: path.Join(group.BasePath(), urlPath),
}

if !route.omitFromDocs {
a.register(openAPIRouteDefinition(route))
a.register(openAPIRouteDefinition(routeID, route))
}

route.noAuthentication = group.noAuthentication
route.noOrgRequired = group.noOrgRequired

handler := func(c *gin.Context) {
if err := wrapRoute(a, route)(c); err != nil {
if err := wrapRoute(a, routeID, route)(c); err != nil {
sendAPIError(c, err)
}
}
bindRoute(a, group.RouterGroup, route.method, route.path, handler)
bindRoute(a, group.RouterGroup, routeID, handler)
}

// wrapRoute builds a gin.HandlerFunc from a route. The returned function
Expand All @@ -191,7 +181,7 @@ func add[Req, Res any](a *API, group *routeGroup, route route[Req, Res]) {
// a request scoped database transaction, authenticates the request, reads the
// request fields into a request struct, and returns an HTTP response with a
// status code and response body built from the response type.
func wrapRoute[Req, Res any](a *API, route route[Req, Res]) func(*gin.Context) error {
func wrapRoute[Req, Res any](a *API, routeID routeIdentifier, route route[Req, Res]) func(*gin.Context) error {
return func(c *gin.Context) error {
if !route.infraVersionHeaderOptional {
if _, err := requestVersion(c.Request); err != nil {
Expand Down Expand Up @@ -239,10 +229,10 @@ func wrapRoute[Req, Res any](a *API, route route[Req, Res]) func(*gin.Context) e
}

if !route.omitFromTelemetry {
a.t.RouteEvent(c, route.path, Properties{"method": strings.ToLower(route.method)})
a.t.RouteEvent(c, routeID.path, Properties{"method": strings.ToLower(routeID.method)})
}

c.JSON(responseStatusCode(route.method, resp), resp)
c.JSON(responseStatusCode(routeID.method, resp), resp)
return nil
}
}
Expand Down Expand Up @@ -284,34 +274,30 @@ func responseStatusCode(method string, resp any) int {
}

func get[Req, Res any](a *API, r *routeGroup, path string, handler HandlerFunc[Req, Res]) {
add(a, r, route[Req, Res]{
method: http.MethodGet,
path: path,
add(a, r, http.MethodGet, path, route[Req, Res]{
handler: handler,
omitFromTelemetry: true,
})
}

func post[Req, Res any](a *API, r *routeGroup, path string, handler HandlerFunc[Req, Res]) {
add(a, r, route[Req, Res]{method: http.MethodPost, path: path, handler: handler})
add(a, r, http.MethodPost, path, route[Req, Res]{handler: handler})
}

func put[Req, Res any](a *API, r *routeGroup, path string, handler HandlerFunc[Req, Res]) {
add(a, r, route[Req, Res]{method: http.MethodPut, path: path, handler: handler})
add(a, r, http.MethodPut, path, route[Req, Res]{handler: handler})
}

func patch[Req, Res any](a *API, r *routeGroup, path string, handler HandlerFunc[Req, Res]) {
add(a, r, route[Req, Res]{method: http.MethodPatch, path: path, handler: handler})
add(a, r, http.MethodPatch, path, route[Req, Res]{handler: handler})
}

func del[Req any, Res any](a *API, r *routeGroup, path string, handler HandlerFunc[Req, Res]) {
add(a, r, route[Req, Res]{method: http.MethodDelete, path: path, handler: handler})
add(a, r, http.MethodDelete, path, route[Req, Res]{handler: handler})
}

func addDeprecated[Req, Res any](a *API, r *routeGroup, method string, path string, handler HandlerFunc[Req, Res]) {
add(a, r, route[Req, Res]{
method: method,
path: path,
add(a, r, method, path, route[Req, Res]{
handler: handler,
omitFromTelemetry: true,
omitFromDocs: true,
Expand Down
8 changes: 2 additions & 6 deletions internal/server/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ func TestWrapRoute_TxnRollbackOnError(t *testing.T) {
router := gin.New()

r := route[api.EmptyRequest, *api.EmptyResponse]{
method: "POST",
path: "/do",
handler: func(c *gin.Context, request *api.EmptyRequest) (*api.EmptyResponse, error) {
rCtx := getRequestContext(c)

Expand All @@ -213,7 +211,7 @@ func TestWrapRoute_TxnRollbackOnError(t *testing.T) {
}

api := &API{server: srv}
add(api, rg(router.Group("/")), r)
add(api, rg(router.Group("/")), "POST", "/do", r)

resp := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/do", nil)
Expand All @@ -233,8 +231,6 @@ func TestWrapRoute_HandleErrorOnCommit(t *testing.T) {
router := gin.New()

r := route[api.EmptyRequest, *api.EmptyResponse]{
method: "POST",
path: "/do",
handler: func(c *gin.Context, request *api.EmptyRequest) (*api.EmptyResponse, error) {
rCtx := getRequestContext(c)

Expand All @@ -248,7 +244,7 @@ func TestWrapRoute_HandleErrorOnCommit(t *testing.T) {
}

api := &API{server: srv}
add(api, rg(router.Group("/")), r)
add(api, rg(router.Group("/")), "POST", "/do", r)

resp := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/do", nil)
Expand Down