From ebf86fc27ead028d625a9e229fbef123aa3f7d7b Mon Sep 17 00:00:00 2001 From: yarne Date: Thu, 24 Mar 2022 10:36:49 +0100 Subject: [PATCH 1/3] Add test case to prove that the security schema is validated after all the other schemas --- openapi3filter/validate_request_test.go | 184 ++++++++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 openapi3filter/validate_request_test.go diff --git a/openapi3filter/validate_request_test.go b/openapi3filter/validate_request_test.go new file mode 100644 index 000000000..4302dcfd8 --- /dev/null +++ b/openapi3filter/validate_request_test.go @@ -0,0 +1,184 @@ +package openapi3filter + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/routers" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTestRouter(t *testing.T, spec string) routers.Router { + t.Helper() + loader := openapi3.NewLoader() + doc, err := loader.LoadFromData([]byte(spec)) + require.NoError(t, err) + + err = doc.Validate(loader.Context) + require.NoError(t, err) + + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err) + + return router +} + +func TestValidateRequest(t *testing.T) { + const spec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: 0.0.1 +paths: + /category: + post: + parameters: + - name: category + in: query + schema: + type: string + required: true + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - subCategory + properties: + subCategory: + type: string + responses: + '201': + description: Created + security: + - apiKey: [] +components: + securitySchemes: + apiKey: + type: apiKey + name: Api-Key + in: header +` + + router := setupTestRouter(t, spec) + + verifyAPIKeyPresence := func(c context.Context, input *AuthenticationInput) error { + if input.SecurityScheme.Type == "apiKey" { + var found bool + switch input.SecurityScheme.In { + case "query": + _, found = input.RequestValidationInput.GetQueryParams()[input.SecurityScheme.Name] + case "header": + _, found = input.RequestValidationInput.Request.Header[http.CanonicalHeaderKey(input.SecurityScheme.Name)] + case "cookie": + _, err := input.RequestValidationInput.Request.Cookie(input.SecurityScheme.Name) + found = errors.Is(err, http.ErrNoCookie) + } + if !found { + return fmt.Errorf("%v not found in %v", input.SecurityScheme.Name, input.SecurityScheme.In) + } + } + return nil + } + + type testRequestBody struct { + SubCategory string `json:"subCategory"` + } + type args struct { + requestBody *testRequestBody + url string + apiKey string + } + tests := []struct { + name string + args args + expectedErr error + }{ + { + name: "Valid request", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate"}, + url: "/category?category=cookies", + apiKey: "SomeKey", + }, + expectedErr: nil, + }, + { + name: "Invalid operation params", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate"}, + url: "/category?invalidCategory=badCookie", + apiKey: "SomeKey", + }, + expectedErr: &RequestError{}, + }, + { + name: "Invalid request body", + args: args{ + requestBody: nil, + url: "/category?category=cookies", + apiKey: "SomeKey", + }, + expectedErr: &RequestError{}, + }, + { + name: "Invalid security", + args: args{ + requestBody: &testRequestBody{SubCategory: "Chocolate"}, + url: "/category?category=cookies", + apiKey: "", + }, + expectedErr: &SecurityRequirementsError{}, + }, + { + name: "Invalid request body and security", + args: args{ + requestBody: nil, + url: "/category?category=cookies", + apiKey: "", + }, + expectedErr: &SecurityRequirementsError{}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var requestBody io.Reader + if tc.args.requestBody != nil { + testingBody, err := json.Marshal(tc.args.requestBody) + require.NoError(t, err) + requestBody = bytes.NewReader(testingBody) + } + req, err := http.NewRequest(http.MethodPost, tc.args.url, requestBody) + require.NoError(t, err) + req.Header.Add("Content-Type", "application/json") + if tc.args.apiKey != "" { + req.Header.Add("Api-Key", tc.args.apiKey) + } + + route, pathParams, err := router.FindRoute(req) + require.NoError(t, err) + + validationInput := &RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + Options: &Options{ + AuthenticationFunc: verifyAPIKeyPresence, + }, + } + err = ValidateRequest(context.Background(), validationInput) + assert.IsType(t, tc.expectedErr, err, "ValidateRequest(): error = %v, expectedError %v", err, tc.expectedErr) + }) + } +} From 1642932a07525124bc08cb8f17c182aeeda1c350 Mon Sep 17 00:00:00 2001 From: yarne Date: Thu, 24 Mar 2022 10:38:02 +0100 Subject: [PATCH 2/3] Make sure security is validated first --- openapi3filter/validate_request.go | 34 +++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index 990b299ef..b1bb84fb1 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -40,6 +40,23 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { operationParameters := operation.Parameters pathItemParameters := route.PathItem.Parameters + // Security + security := operation.Security + // If there aren't any security requirements for the operation + if security == nil { + // Use the global security requirements. + security = &route.Spec.Security + } + if security != nil { + if err = ValidateSecurityRequirements(ctx, input, *security); err != nil && !options.MultiError { + return err + } + + if err != nil { + me = append(me, err) + } + } + // For each parameter of the PathItem for _, parameterRef := range pathItemParameters { parameter := parameterRef.Value @@ -81,23 +98,6 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { } } - // Security - security := operation.Security - // If there aren't any security requirements for the operation - if security == nil { - // Use the global security requirements. - security = &route.Spec.Security - } - if security != nil { - if err = ValidateSecurityRequirements(ctx, input, *security); err != nil && !options.MultiError { - return err - } - - if err != nil { - me = append(me, err) - } - } - if len(me) > 0 { return me } From 809c30c1c769d3b54804fc13ca616856f64b9b0c Mon Sep 17 00:00:00 2001 From: Yarn-e Date: Thu, 24 Mar 2022 15:53:26 +0100 Subject: [PATCH 3/3] This is the actual correct check --- openapi3filter/validate_request_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openapi3filter/validate_request_test.go b/openapi3filter/validate_request_test.go index 4302dcfd8..b43f6c813 100644 --- a/openapi3filter/validate_request_test.go +++ b/openapi3filter/validate_request_test.go @@ -83,7 +83,7 @@ components: _, found = input.RequestValidationInput.Request.Header[http.CanonicalHeaderKey(input.SecurityScheme.Name)] case "cookie": _, err := input.RequestValidationInput.Request.Cookie(input.SecurityScheme.Name) - found = errors.Is(err, http.ErrNoCookie) + found = !errors.Is(err, http.ErrNoCookie) } if !found { return fmt.Errorf("%v not found in %v", input.SecurityScheme.Name, input.SecurityScheme.In)