Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

race condition OpenAPI requestValidationInput #231

Merged
merged 2 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 44 additions & 24 deletions handler/transport/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}

t := Get(tc)

deadlineErr := b.withTimeout(req)

req.URL.Host = tc.Origin
Expand All @@ -105,13 +103,6 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Del(AcceptEncodingHeader)
}

if b.openAPIValidator != nil {
// FIXME tc.Origin should be an origin, not just a host!
if err = b.openAPIValidator.ValidateRequest(req, tc.hash(), tc.Scheme+"://"+tc.Origin); err != nil {
return nil, errors.BackendValidation.Label(b.name).Kind("backend_request_validation").With(err)
}
}

if xff, ok := req.Context().Value(request.XFF).(string); ok {
if xff != "" {
req.Header.Set("X-Forwarded-For", xff)
Expand All @@ -129,23 +120,15 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := context.WithValue(req.Context(), request.BackendURL, req.URL.String())
*req = *req.WithContext(ctx)

beresp, err := t.RoundTrip(req)
if err != nil {
select {
case derr := <-deadlineErr:
if derr != nil {
return nil, derr
}
default:
return nil, errors.Backend.Label(b.name).With(err)
}
var beresp *http.Response
if b.openAPIValidator != nil {
beresp, err = b.openAPIValidate(req, tc, deadlineErr)
} else {
beresp, err = b.innerRoundTrip(req, tc, deadlineErr)
}

if b.openAPIValidator != nil {
if err = b.openAPIValidator.ValidateResponse(beresp); err != nil {
return nil, errors.BackendValidation.Label(b.name).Kind("backend_response_validation").
With(err).Status(http.StatusBadGateway)
}
if err != nil {
return nil, err
}

if strings.ToLower(beresp.Header.Get(ContentEncodingHeader)) == GzipName {
Expand Down Expand Up @@ -173,6 +156,43 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
return beresp, err
}

func (b *Backend) openAPIValidate(req *http.Request, tc *Config, deadlineErr <-chan error) (*http.Response, error) {
requestValidationInput, err := b.openAPIValidator.ValidateRequest(req, tc.hash(), tc.Scheme+"://"+tc.Origin)
if err != nil {
return nil, errors.BackendValidation.Label(b.name).Kind("backend_request_validation").With(err)
}

beresp, err := b.innerRoundTrip(req, tc, deadlineErr)
if err != nil {
return nil, err
}

if err = b.openAPIValidator.ValidateResponse(beresp, requestValidationInput); err != nil {
return nil, errors.BackendValidation.Label(b.name).Kind("backend_response_validation").
With(err).Status(http.StatusBadGateway)
}

return beresp, nil
}

func (b *Backend) innerRoundTrip(req *http.Request, tc *Config, deadlineErr <-chan error) (*http.Response, error) {
t := Get(tc)
beresp, err := t.RoundTrip(req)

if err != nil {
select {
case derr := <-deadlineErr:
if derr != nil {
return nil, derr
}
default:
return nil, errors.Backend.Label(b.name).With(err)
}
}

return beresp, nil
}

func (b *Backend) withPathPrefix(req *http.Request) {
if pathPrefix := b.getAttribute(req, "path_prefix"); pathPrefix != "" {
req.URL.Path = utils.JoinPath("/", pathPrefix, req.URL.Path)
Expand Down
29 changes: 14 additions & 15 deletions handler/validation/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ import (
var swaggers sync.Map

type OpenAPI struct {
options *OpenAPIOptions
requestValidationInput *openapi3filter.RequestValidationInput
options *OpenAPIOptions
}

func NewOpenAPI(opts *OpenAPIOptions) *OpenAPI {
Expand Down Expand Up @@ -69,21 +68,21 @@ func cloneSwagger(s *openapi3.Swagger) *openapi3.Swagger {
return &sw
}

func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) error {
func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) (*openapi3filter.RequestValidationInput, error) {
swagger, err := v.getModifiedSwagger(key, origin)
if err != nil {
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
return nil, err
}
return nil
return nil, nil
}

router := openapi3filter.NewRouter()
if err = router.AddSwagger(swagger); err != nil {
return err
return nil, err
}

route, pathParams, err := router.FindRoute(req.Method, req.URL)
Expand All @@ -93,12 +92,12 @@ func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) error {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
return nil, err
}
return nil
return nil, nil
}

v.requestValidationInput = &openapi3filter.RequestValidationInput{
requestValidationInput := &openapi3filter.RequestValidationInput{
Options: v.options.filterOptions,
PathParams: pathParams,
QueryParams: req.URL.Query(),
Expand All @@ -107,21 +106,21 @@ func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) error {
}

// openapi3filter.ValidateRequestBody also handles resetting the req body after reading until EOF.
if err = openapi3filter.ValidateRequest(req.Context(), v.requestValidationInput); err != nil {
if err = openapi3filter.ValidateRequest(req.Context(), requestValidationInput); err != nil {
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
return requestValidationInput, err
}
}

return nil
return requestValidationInput, nil
}

func (v *OpenAPI) ValidateResponse(beresp *http.Response) error {
func (v *OpenAPI) ValidateResponse(beresp *http.Response, requestValidationInput *openapi3filter.RequestValidationInput) error {
// since a request validation could fail and ignored due to user options, the input route MAY be nil
if v.requestValidationInput == nil || v.requestValidationInput.Route == nil {
if requestValidationInput == nil || requestValidationInput.Route == nil {
err := fmt.Errorf("'%s %s': invalid route", beresp.Request.Method, beresp.Request.URL.Path)
if beresp.Request != nil {
if ctx, ok := beresp.Request.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
Expand All @@ -139,7 +138,7 @@ func (v *OpenAPI) ValidateResponse(beresp *http.Response) error {
Body: ioutil.NopCloser(&bytes.Buffer{}),
Header: beresp.Header.Clone(),
Options: v.options.filterOptions,
RequestValidationInput: v.requestValidationInput,
RequestValidationInput: requestValidationInput,
Status: beresp.StatusCode,
}

Expand Down
43 changes: 43 additions & 0 deletions server/http_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -1875,6 +1876,48 @@ func TestHTTPServer_Endpoint_Evaluation_Inheritance_Backend_Block(t *testing.T)
}
}

func TestOpenAPIValidateConcurrentRequests(t *testing.T) {
helper := test.New(t)
client := newClient()

shutdown, _ := newCouper("testdata/integration/validation/01_couper.hcl", test.New(t))
defer shutdown()

req1, err := http.NewRequest(http.MethodGet, "http://example.com:8080/anything", nil)
helper.Must(err)
req2, err := http.NewRequest(http.MethodGet, "http://example.com:8080/pdf", nil)
helper.Must(err)

var res1, res2 *http.Response
var err1, err2 error
waitCh := make(chan struct{})
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
<-waitCh // blocks
res1, err1 = client.Do(req1)
}()
go func() {
defer wg.Done()
<-waitCh // blocks
res2, err2 = client.Do(req2)
}()

close(waitCh) // triggers reqs
wg.Wait()

helper.Must(err1)
helper.Must(err2)

if res1.StatusCode != 200 {
t.Errorf("Expected status %d for response1; got: %d", 200, res1.StatusCode)
}
if res2.StatusCode != 502 {
t.Errorf("Expected status %d for response2; got: %d", 502, res2.StatusCode)
}
}

func TestConfigBodyContent(t *testing.T) {
helper := test.New(t)
client := newClient()
Expand Down
18 changes: 18 additions & 0 deletions server/testdata/integration/validation/01_couper.hcl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
server "concurrent-requests" {
api {
endpoint "/**" {
proxy {
backend = "test-be"
}
}
}
}

definitions {
backend "test-be" {
origin = env.COUPER_TEST_BACKEND_ADDR
openapi {
file = "01_schema.yaml"
}
}
}
17 changes: 17 additions & 0 deletions server/testdata/integration/validation/01_schema.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
openapi: '3'
info:
title: 'Couper backend validation test'
version: 'v1.2.3'
paths:
/anything:
get:
responses:
200:
description: OK
schema:
type: object
/pdf:
get:
responses:
555:
description: OK