Skip to content

Commit

Permalink
openAPIValidate function in Backend; store requestValidationInput in …
Browse files Browse the repository at this point in the history
…local variable
  • Loading branch information
Johannes Koch committed May 17, 2021
1 parent 918bf92 commit c771404
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
62 changes: 42 additions & 20 deletions handler/transport/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,6 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Del(AcceptEncodingHeader)
}

if b.openAPIValidator != nil {
if err = b.openAPIValidator.ValidateRequest(req); err != nil {
return nil, errors.BackendValidation.Label(b.name).With(err)
}
}

if xff, ok := req.Context().Value(request.XFF).(string); ok {
if xff != "" {
req.Header.Set("X-Forwarded-For", xff)
Expand All @@ -128,22 +122,15 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := context.WithValue(req.Context(), request.BackendURL, req.URL.String())
*req = *req.WithContext(ctx)

beresp, err := t.RoundTrip(req)
if err != nil {
select {
case derr := <-deadlineErr:
if derr != nil {
return nil, derr
}
default:
return nil, errors.Backend.Label(b.name).With(err)
}
var beresp *http.Response
if b.openAPIValidator != nil {
beresp, err = b.openAPIValidate(req, t, deadlineErr)
} else {
beresp, err = b.innerRoundTrip(req, t, deadlineErr)
}

if b.openAPIValidator != nil {
if err = b.openAPIValidator.ValidateResponse(beresp); err != nil {
return nil, errors.BackendValidation.Label(b.name).With(err).Status(http.StatusBadGateway)
}
if err != nil {
return nil, err
}

if strings.ToLower(beresp.Header.Get(ContentEncodingHeader)) == GzipName {
Expand Down Expand Up @@ -171,6 +158,41 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
return beresp, err
}

func (b *Backend) openAPIValidate(req *http.Request, t *http.Transport, deadlineErr <-chan error) (*http.Response, error) {
requestValidationInput, err := b.openAPIValidator.ValidateRequest(req)
if err != nil {
return nil, errors.BackendValidation.Label(b.name).With(err)
}

beresp, err := b.innerRoundTrip(req, t, deadlineErr)
if err != nil {
return nil, err
}

if err = b.openAPIValidator.ValidateResponse(beresp, requestValidationInput); err != nil {
return nil, errors.BackendValidation.Label(b.name).With(err).Status(http.StatusBadGateway)
}

return beresp, nil
}

func (b *Backend) innerRoundTrip(req *http.Request, t *http.Transport, deadlineErr <-chan error) (*http.Response, error) {
beresp, err := t.RoundTrip(req)

if err != nil {
select {
case derr := <-deadlineErr:
if derr != nil {
return nil, derr
}
default:
return nil, errors.Backend.Label(b.name).With(err)
}
}

return beresp, nil
}

func (b *Backend) withPathPrefix(req *http.Request) {
if pathPrefix := b.getAttribute(req, "path_prefix"); pathPrefix != "" {
req.URL.Path = utils.JoinPath("/", pathPrefix, req.URL.Path)
Expand Down
24 changes: 11 additions & 13 deletions handler/validation/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ import (
)

type OpenAPI struct {
options *OpenAPIOptions
requestValidationInput *openapi3filter.RequestValidationInput
options *OpenAPIOptions
}

func NewOpenAPI(opts *OpenAPIOptions) *OpenAPI {
Expand All @@ -27,21 +26,20 @@ func NewOpenAPI(opts *OpenAPIOptions) *OpenAPI {
options: opts,
}
}

func (v *OpenAPI) ValidateRequest(req *http.Request) error {
func (v *OpenAPI) ValidateRequest(req *http.Request) (*openapi3filter.RequestValidationInput, error) {
route, pathParams, err := v.options.router.FindRoute(req.Method, req.URL)
if err != nil {
err = fmt.Errorf("request validation: '%s %s': %w", req.Method, req.URL.Path, err)
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
return nil, err
}
return nil
return nil, nil
}

v.requestValidationInput = &openapi3filter.RequestValidationInput{
requestValidationInput := &openapi3filter.RequestValidationInput{
Options: v.options.filterOptions,
PathParams: pathParams,
QueryParams: req.URL.Query(),
Expand All @@ -50,24 +48,24 @@ func (v *OpenAPI) ValidateRequest(req *http.Request) error {
}

// openapi3filter.ValidateRequestBody also handles resetting the req body after reading until EOF.
err = openapi3filter.ValidateRequest(req.Context(), v.requestValidationInput)
err = openapi3filter.ValidateRequest(req.Context(), requestValidationInput)

if err != nil {
err = fmt.Errorf("request validation: %w", err)
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
return requestValidationInput, err
}
}

return nil
return requestValidationInput, nil
}

func (v *OpenAPI) ValidateResponse(beresp *http.Response) error {
func (v *OpenAPI) ValidateResponse(beresp *http.Response, requestValidationInput *openapi3filter.RequestValidationInput) error {
// since a request validation could fail and ignored due to user options, the input route MAY be nil
if v.requestValidationInput == nil || v.requestValidationInput.Route == nil {
if requestValidationInput == nil || requestValidationInput.Route == nil {
err := fmt.Errorf("response validation: '%s %s': invalid route", beresp.Request.Method, beresp.Request.URL.Path)
if beresp.Request != nil {
if ctx, ok := beresp.Request.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
Expand All @@ -85,7 +83,7 @@ func (v *OpenAPI) ValidateResponse(beresp *http.Response) error {
Body: ioutil.NopCloser(&bytes.Buffer{}),
Header: beresp.Header.Clone(),
Options: v.options.filterOptions,
RequestValidationInput: v.requestValidationInput,
RequestValidationInput: requestValidationInput,
Status: beresp.StatusCode,
}

Expand Down

0 comments on commit c771404

Please sign in to comment.