Skip to content

Commit

Permalink
openAPIValidate function in Backend; store requestValidationInput in …
Browse files Browse the repository at this point in the history
…local variable
  • Loading branch information
Johannes Koch authored and Marcel Ludwig committed May 17, 2021
1 parent abd9aa3 commit 2ef4dff
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 39 deletions.
68 changes: 44 additions & 24 deletions handler/transport/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}

t := Get(tc)

deadlineErr := b.withTimeout(req)

req.URL.Host = tc.Origin
Expand All @@ -105,13 +103,6 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Del(AcceptEncodingHeader)
}

if b.openAPIValidator != nil {
// FIXME tc.Origin should be an origin, not just a host!
if err = b.openAPIValidator.ValidateRequest(req, tc.hash(), tc.Scheme+"://"+tc.Origin); err != nil {
return nil, errors.BackendValidation.Label(b.name).Kind("backend_request_validation").With(err)
}
}

if xff, ok := req.Context().Value(request.XFF).(string); ok {
if xff != "" {
req.Header.Set("X-Forwarded-For", xff)
Expand All @@ -129,23 +120,15 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := context.WithValue(req.Context(), request.BackendURL, req.URL.String())
*req = *req.WithContext(ctx)

beresp, err := t.RoundTrip(req)
if err != nil {
select {
case derr := <-deadlineErr:
if derr != nil {
return nil, derr
}
default:
return nil, errors.Backend.Label(b.name).With(err)
}
var beresp *http.Response
if b.openAPIValidator != nil {
beresp, err = b.openAPIValidate(req, tc, deadlineErr)
} else {
beresp, err = b.innerRoundTrip(req, tc, deadlineErr)
}

if b.openAPIValidator != nil {
if err = b.openAPIValidator.ValidateResponse(beresp); err != nil {
return nil, errors.BackendValidation.Label(b.name).Kind("backend_response_validation").
With(err).Status(http.StatusBadGateway)
}
if err != nil {
return nil, err
}

if strings.ToLower(beresp.Header.Get(ContentEncodingHeader)) == GzipName {
Expand Down Expand Up @@ -173,6 +156,43 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
return beresp, err
}

func (b *Backend) openAPIValidate(req *http.Request, tc *Config, deadlineErr <-chan error) (*http.Response, error) {
requestValidationInput, err := b.openAPIValidator.ValidateRequest(req, tc.hash(), tc.Scheme+"://"+tc.Origin)
if err != nil {
return nil, errors.BackendValidation.Label(b.name).Kind("backend_request_validation").With(err)
}

beresp, err := b.innerRoundTrip(req, tc, deadlineErr)
if err != nil {
return nil, err
}

if err = b.openAPIValidator.ValidateResponse(beresp, requestValidationInput); err != nil {
return nil, errors.BackendValidation.Label(b.name).Kind("backend_response_validation").
With(err).Status(http.StatusBadGateway)
}

return beresp, nil
}

func (b *Backend) innerRoundTrip(req *http.Request, tc *Config, deadlineErr <-chan error) (*http.Response, error) {
t := Get(tc)
beresp, err := t.RoundTrip(req)

if err != nil {
select {
case derr := <-deadlineErr:
if derr != nil {
return nil, derr
}
default:
return nil, errors.Backend.Label(b.name).With(err)
}
}

return beresp, nil
}

func (b *Backend) withPathPrefix(req *http.Request) {
if pathPrefix := b.getAttribute(req, "path_prefix"); pathPrefix != "" {
req.URL.Path = utils.JoinPath("/", pathPrefix, req.URL.Path)
Expand Down
29 changes: 14 additions & 15 deletions handler/validation/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ import (
var swaggers sync.Map

type OpenAPI struct {
options *OpenAPIOptions
requestValidationInput *openapi3filter.RequestValidationInput
options *OpenAPIOptions
}

func NewOpenAPI(opts *OpenAPIOptions) *OpenAPI {
Expand Down Expand Up @@ -69,21 +68,21 @@ func cloneSwagger(s *openapi3.Swagger) *openapi3.Swagger {
return &sw
}

func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) error {
func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) (*openapi3filter.RequestValidationInput, error) {
swagger, err := v.getModifiedSwagger(key, origin)
if err != nil {
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
return nil, err
}
return nil
return nil, nil
}

router := openapi3filter.NewRouter()
if err = router.AddSwagger(swagger); err != nil {
return err
return nil, err
}

route, pathParams, err := router.FindRoute(req.Method, req.URL)
Expand All @@ -93,12 +92,12 @@ func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) error {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
return nil, err
}
return nil
return nil, nil
}

v.requestValidationInput = &openapi3filter.RequestValidationInput{
requestValidationInput := &openapi3filter.RequestValidationInput{
Options: v.options.filterOptions,
PathParams: pathParams,
QueryParams: req.URL.Query(),
Expand All @@ -107,21 +106,21 @@ func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) error {
}

// openapi3filter.ValidateRequestBody also handles resetting the req body after reading until EOF.
if err = openapi3filter.ValidateRequest(req.Context(), v.requestValidationInput); err != nil {
if err = openapi3filter.ValidateRequest(req.Context(), requestValidationInput); err != nil {
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
return requestValidationInput, err
}
}

return nil
return requestValidationInput, nil
}

func (v *OpenAPI) ValidateResponse(beresp *http.Response) error {
func (v *OpenAPI) ValidateResponse(beresp *http.Response, requestValidationInput *openapi3filter.RequestValidationInput) error {
// since a request validation could fail and ignored due to user options, the input route MAY be nil
if v.requestValidationInput == nil || v.requestValidationInput.Route == nil {
if requestValidationInput == nil || requestValidationInput.Route == nil {
err := fmt.Errorf("'%s %s': invalid route", beresp.Request.Method, beresp.Request.URL.Path)
if beresp.Request != nil {
if ctx, ok := beresp.Request.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
Expand All @@ -139,7 +138,7 @@ func (v *OpenAPI) ValidateResponse(beresp *http.Response) error {
Body: ioutil.NopCloser(&bytes.Buffer{}),
Header: beresp.Header.Clone(),
Options: v.options.filterOptions,
RequestValidationInput: v.requestValidationInput,
RequestValidationInput: requestValidationInput,
Status: beresp.StatusCode,
}

Expand Down

0 comments on commit 2ef4dff

Please sign in to comment.