diff --git a/openapi3filter/middleware.go b/openapi3filter/middleware.go new file mode 100644 index 000000000..3709faf9b --- /dev/null +++ b/openapi3filter/middleware.go @@ -0,0 +1,273 @@ +package openapi3filter + +import ( + "bytes" + "io" + "io/ioutil" + "log" + "net/http" + + "github.com/getkin/kin-openapi/routers" +) + +// Validator provides HTTP request and response validation middleware. +type Validator struct { + router routers.Router + errFunc ErrFunc + logFunc LogFunc + strict bool +} + +// ErrFunc handles errors that may occur during validation. +type ErrFunc func(w http.ResponseWriter, status int, code ErrCode, err error) + +// LogFunc handles log messages that may occur during validation. +type LogFunc func(message string, err error) + +// ErrCode is used for classification of different types of errors that may +// occur during validation. These may be used to write an appropriate response +// in ErrFunc. +type ErrCode int + +const ( + // ErrCodeOK indicates no error. It is also the default value. + ErrCodeOK = 0 + // ErrCodeCannotFindRoute happens when the validator fails to resolve the + // request to a defined OpenAPI route. + ErrCodeCannotFindRoute = iota + // ErrCodeRequestInvalid happens when the inbound request does not conform + // to the OpenAPI 3 specification. + ErrCodeRequestInvalid = iota + // ErrCodeResponseInvalid happens when the wrapped handler response does + // not conform to the OpenAPI 3 specification. + ErrCodeResponseInvalid = iota +) + +func (e ErrCode) responseText() string { + switch e { + case ErrCodeOK: + return "OK" + case ErrCodeCannotFindRoute: + return "not found" + case ErrCodeRequestInvalid: + return "bad request" + default: + return "server error" + } +} + +// NewValidator returns a new response validation middlware, using the given +// routes from an OpenAPI 3 specification. +func NewValidator(router routers.Router, options ...ValidatorOption) *Validator { + v := &Validator{ + router: router, + errFunc: func(w http.ResponseWriter, status int, code ErrCode, _ error) { + http.Error(w, code.responseText(), status) + }, + logFunc: func(message string, err error) { + log.Printf("%s: %v", message, err) + }, + } + for i := range options { + options[i](v) + } + return v +} + +// ValidatorOption defines an option that may be specified when creating a +// Validator. +type ValidatorOption func(*Validator) + +// OnErr provides a callback that handles writing an HTTP response on a +// validation error. This allows customization of error responses without +// prescribing a particular form. This callback is only called on response +// validator errors in Strict mode. +func OnErr(f ErrFunc) ValidatorOption { + return func(v *Validator) { + v.errFunc = f + } +} + +// OnLog provides a callback that handles logging in the Validator. This allows +// the validator to integrate with a services' existing logging system without +// prescribing a particular one. +func OnLog(f LogFunc) ValidatorOption { + return func(v *Validator) { + v.logFunc = f + } +} + +// Strict, if set, causes an internal server error to be sent if the wrapped +// handler response fails response validation. If not set, the response is sent +// and the error is only logged. +func Strict(strict bool) ValidatorOption { + return func(v *Validator) { + v.strict = strict + } +} + +// Middleware returns an http.Handler which wraps the given handler with +// request and response validation. +func (v *Validator) Middleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + route, pathParams, err := v.router.FindRoute(r) + if err != nil { + v.logFunc("validation error: failed to find route for "+r.URL.String(), err) + v.errFunc(w, http.StatusNotFound, ErrCodeCannotFindRoute, err) + return + } + requestValidationInput := &RequestValidationInput{ + Request: r, + PathParams: pathParams, + Route: route, + } + if err = ValidateRequest(r.Context(), requestValidationInput); err != nil { + v.logFunc("invalid request", err) + v.errFunc(w, http.StatusBadRequest, ErrCodeRequestInvalid, err) + return + } + + var wr responseWrapper + if v.strict { + wr = &strictResponseWrapper{w: w} + } else { + wr = newWarnResponseWrapper(w) + } + + h.ServeHTTP(wr, r) + + if err = ValidateResponse(r.Context(), &ResponseValidationInput{ + RequestValidationInput: requestValidationInput, + Status: wr.statusCode(), + Header: wr.Header(), + Body: ioutil.NopCloser(bytes.NewBuffer(wr.bodyContents())), + }); err != nil { + v.logFunc("invalid response", err) + if v.strict { + v.errFunc(w, http.StatusInternalServerError, ErrCodeResponseInvalid, err) + } + return + } + + if err = wr.flushBodyContents(); err != nil { + v.logFunc("failed to write response", err) + } + }) +} + +type responseWrapper interface { + http.ResponseWriter + + // flushBodyContents writes the buffered response to the client, if it has + // not yet been written. + flushBodyContents() error + + // statusCode returns the response status code, 0 if not set yet. + statusCode() int + + // bodyContents returns the buffered + bodyContents() []byte +} + +type warnResponseWrapper struct { + w http.ResponseWriter + headerWritten bool + status int + body bytes.Buffer + tee io.Writer +} + +func newWarnResponseWrapper(w http.ResponseWriter) *warnResponseWrapper { + wr := &warnResponseWrapper{ + w: w, + } + wr.tee = io.MultiWriter(w, &wr.body) + return wr +} + +// Write implements http.ResponseWriter. +func (wr *warnResponseWrapper) Write(b []byte) (int, error) { + if !wr.headerWritten { + wr.WriteHeader(http.StatusOK) + } + return wr.tee.Write(b) +} + +// WriteHeader implements http.ResponseWriter. +func (wr *warnResponseWrapper) WriteHeader(status int) { + if !wr.headerWritten { + // If the header hasn't been written, record the status for response + // validation. + wr.status = status + wr.headerWritten = true + } + wr.w.WriteHeader(wr.status) +} + +// Header implements http.ResponseWriter. +func (wr *warnResponseWrapper) Header() http.Header { + return wr.w.Header() +} + +// Flush implements the optional http.Flusher interface. +func (wr *warnResponseWrapper) Flush() { + // If the wrapped http.ResponseWriter implements optional http.Flusher, + // pass through. + if fl, ok := wr.w.(http.Flusher); ok { + fl.Flush() + } +} + +func (wr *warnResponseWrapper) flushBodyContents() error { + return nil +} + +func (wr *warnResponseWrapper) statusCode() int { + return wr.status +} + +func (wr *warnResponseWrapper) bodyContents() []byte { + return wr.body.Bytes() +} + +type strictResponseWrapper struct { + w http.ResponseWriter + headerWritten bool + status int + body bytes.Buffer +} + +// Write implements http.ResponseWriter. +func (wr *strictResponseWrapper) Write(b []byte) (int, error) { + if !wr.headerWritten { + wr.WriteHeader(http.StatusOK) + } + return wr.body.Write(b) +} + +// WriteHeader implements http.ResponseWriter. +func (wr *strictResponseWrapper) WriteHeader(status int) { + if !wr.headerWritten { + wr.status = status + wr.headerWritten = true + } +} + +// Header implements http.ResponseWriter. +func (wr *strictResponseWrapper) Header() http.Header { + return wr.w.Header() +} + +func (wr *strictResponseWrapper) flushBodyContents() error { + wr.w.WriteHeader(wr.status) + _, err := wr.w.Write(wr.body.Bytes()) + return err +} + +func (wr *strictResponseWrapper) statusCode() int { + return wr.status +} + +func (wr *strictResponseWrapper) bodyContents() []byte { + return wr.body.Bytes() +} diff --git a/openapi3filter/middleware_test.go b/openapi3filter/middleware_test.go new file mode 100644 index 000000000..45869fb44 --- /dev/null +++ b/openapi3filter/middleware_test.go @@ -0,0 +1,500 @@ +package openapi3filter_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "path" + "regexp" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers/gorillamux" +) + +const validatorSpec = ` +openapi: 3.0.0 +info: + title: 'Validator' + version: '0.0.0' +paths: + /test: + post: + operationId: newTest + description: create a new test + parameters: + - in: query + name: version + schema: + type: string + required: true + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/TestContents' } + responses: + '201': + description: 'created test' + content: + application/json: + schema: { $ref: '#/components/schemas/TestResource' } + '400': { $ref: '#/components/responses/ErrorResponse' } + '500': { $ref: '#/components/responses/ErrorResponse' } + /test/{id}: + get: + operationId: getTest + description: get a test + parameters: + - in: path + name: id + schema: + type: string + required: true + - in: query + name: version + schema: + type: string + required: true + responses: + '200': + description: 'respond with test resource' + content: + application/json: + schema: { $ref: '#/components/schemas/TestResource' } + '400': { $ref: '#/components/responses/ErrorResponse' } + '404': { $ref: '#/components/responses/ErrorResponse' } + '500': { $ref: '#/components/responses/ErrorResponse' } +components: + schemas: + TestContents: + type: object + properties: + name: + type: string + expected: + type: number + actual: + type: number + required: [name, expected, actual] + additionalProperties: false + TestResource: + type: object + properties: + id: + type: string + contents: + { $ref: '#/components/schemas/TestContents' } + required: [id, contents] + additionalProperties: false + Error: + type: object + properties: + code: + type: string + message: + type: string + required: [code, message] + additionalProperties: false + responses: + ErrorResponse: + description: 'an error occurred' + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } +` + +type validatorTestHandler struct { + contentType string + getBody, postBody string + errBody string + errStatusCode int +} + +const validatorOkResponse = `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}}` + +func (h validatorTestHandler) withDefaults() validatorTestHandler { + if h.contentType == "" { + h.contentType = "application/json" + } + if h.getBody == "" { + h.getBody = validatorOkResponse + } + if h.postBody == "" { + h.postBody = validatorOkResponse + } + if h.errBody == "" { + h.errBody = `{"code":"bad","message":"bad things"}` + } + return h +} + +var testUrlRE = regexp.MustCompile(`^/test(/\d+)?$`) + +func (h *validatorTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", h.contentType) + if h.errStatusCode != 0 { + w.WriteHeader(h.errStatusCode) + w.Write([]byte(h.errBody)) + return + } + if !testUrlRE.MatchString(r.URL.Path) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(h.errBody)) + return + } + switch r.Method { + case "GET": + w.WriteHeader(http.StatusOK) + w.Write([]byte(h.getBody)) + case "POST": + w.WriteHeader(http.StatusCreated) + w.Write([]byte(h.postBody)) + default: + http.Error(w, h.errBody, http.StatusMethodNotAllowed) + } +} + +func TestValidator(t *testing.T) { + doc, err := openapi3.NewLoader().LoadFromData([]byte(validatorSpec)) + require.NoError(t, err, "failed to load test fixture spec") + + ctx := context.Background() + err = doc.Validate(ctx) + require.NoError(t, err, "invalid test fixture spec") + + type testRequest struct { + method, path, body, contentType string + } + type testResponse struct { + statusCode int + body string + } + tests := []struct { + name string + handler validatorTestHandler + options []openapi3filter.ValidatorOption + request testRequest + response testResponse + strict bool + }{{ + name: "valid GET", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42?version=1", + }, + response: testResponse{ + 200, validatorOkResponse, + }, + strict: true, + }, { + name: "valid POST", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": 9, "actual": 10}`, + contentType: "application/json", + }, + response: testResponse{ + 201, validatorOkResponse, + }, + strict: true, + }, { + name: "not found; no GET operation for /test", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test?version=1", + }, + response: testResponse{ + 404, "not found\n", + }, + strict: true, + }, { + name: "not found; no POST operation for /test/42", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test/42?version=1", + }, + response: testResponse{ + 404, "not found\n", + }, + strict: true, + }, { + name: "invalid request; missing version", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42", + }, + response: testResponse{ + 400, "bad request\n", + }, + strict: true, + }, { + name: "invalid POST request; wrong property type", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": "nine", "actual": "ten"}`, + contentType: "application/json", + }, + response: testResponse{ + 400, "bad request\n", + }, + strict: true, + }, { + name: "invalid POST request; missing property", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": 9}`, + contentType: "application/json", + }, + response: testResponse{ + 400, "bad request\n", + }, + strict: true, + }, { + name: "invalid POST request; extra property", + handler: validatorTestHandler{}.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": 9, "actual": 10, "ideal": 8}`, + contentType: "application/json", + }, + response: testResponse{ + 400, "bad request\n", + }, + strict: true, + }, { + name: "valid response; 404 error", + handler: validatorTestHandler{ + contentType: "application/json", + errBody: `{"code": "404", "message": "not found"}`, + errStatusCode: 404, + }.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42?version=1", + }, + response: testResponse{ + 404, `{"code": "404", "message": "not found"}`, + }, + strict: true, + }, { + name: "invalid response; invalid error", + handler: validatorTestHandler{ + errBody: `"not found"`, + errStatusCode: 404, + }.withDefaults(), + request: testRequest{ + method: "GET", + path: "/test/42?version=1", + }, + response: testResponse{ + 500, "server error\n", + }, + strict: true, + }, { + name: "invalid POST response; not strict", + handler: validatorTestHandler{ + postBody: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, + }.withDefaults(), + request: testRequest{ + method: "POST", + path: "/test?version=1", + body: `{"name": "foo", "expected": 9, "actual": 10}`, + contentType: "application/json", + }, + response: testResponse{ + statusCode: 201, + body: `{"id": "42", "contents": {"name": "foo", "expected": 9, "actual": 10}, "extra": true}`, + }, + strict: false, + }} + for i, test := range tests { + t.Logf("test#%d: %s", i, test.name) + t.Run(test.name, func(t *testing.T) { + // Set up a test HTTP server + var h http.Handler + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + })) + defer s.Close() + + // Update the OpenAPI servers section with the test server URL This is + // needed by the router which matches request routes for OpenAPI + // validation. + doc.Servers = []*openapi3.Server{{URL: s.URL}} + err = doc.Validate(ctx) + require.NoError(t, err, "failed to validate with test server") + + // Create the router and validator + router, err := gorillamux.NewRouter(doc) + require.NoError(t, err, "failed to create router") + + // Now wrap the test handler with the validator middlware + v := openapi3filter.NewValidator(router, append(test.options, openapi3filter.Strict(test.strict))...) + h = v.Middleware(&test.handler) + + // Test: make a client request + var requestBody io.Reader + if test.request.body != "" { + requestBody = bytes.NewBufferString(test.request.body) + } + req, err := http.NewRequest(test.request.method, s.URL+test.request.path, requestBody) + require.NoError(t, err, "failed to create request") + + if test.request.contentType != "" { + req.Header.Set("Content-Type", test.request.contentType) + } + resp, err := s.Client().Do(req) + require.NoError(t, err, "request failed") + defer resp.Body.Close() + require.Equalf(t, test.response.statusCode, resp.StatusCode, + "response code expect %d got %d", test.response.statusCode, resp.StatusCode) + + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err, "failed to read response body") + require.Equalf(t, test.response.body, string(body), + "response body expect %q got %q", test.response.body, string(body)) + }) + } +} + +func ExampleValidator() { + // OpenAPI specification for a simple service that squares integers, with + // some limitations. + doc, err := openapi3.NewLoader().LoadFromData([]byte(` +info: +openapi: 3.0.0 +info: + title: 'Validator - square example' + version: '0.0.0' +paths: + /square/{x}: + get: + description: square an integer + parameters: + - name: x + in: path + schema: + type: integer + required: true + responses: + '200': + description: squared integer response + content: + "application/json": + schema: + type: object + properties: + result: + type: integer + minimum: 0 + maximum: 1000000 + required: [result] + additionalProperties: false`[1:])) + if err != nil { + panic(err) + } + + // Square service handler sanity checks inputs, but just crashes on invalid + // requests. + squareHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + xParam := path.Base(r.URL.Path) + x, err := strconv.Atoi(xParam) + if err != nil { + panic(err) + } + w.Header().Set("Content-Type", "application/json") + result := map[string]interface{}{"result": x * x} + if x == 42 { + // An easter egg. Unfortunately, the spec does not allow additional properties... + result["comment"] = "the answer to the ulitimate question of life, the universe, and everything" + } + if err = json.NewEncoder(w).Encode(&result); err != nil { + panic(err) + } + }) + + // Start an http server. + var mainHandler http.Handler + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Why are we wrapping the main server handler with a closure here? + // Validation matches request Host: to server URLs in the spec. With an + // httptest.Server, the URL is dynamic and we have to create it first! + // In a real configured service, this is less likely to be needed. + mainHandler.ServeHTTP(w, r) + })) + defer srv.Close() + + // Patch the OpenAPI spec to match the httptest.Server.URL. Only needed + // because the server URL is dynamic here. + doc.Servers = []*openapi3.Server{{URL: srv.URL}} + if err := doc.Validate(context.Background()); err != nil { // Assert our OpenAPI is valid! + panic(err) + } + // This router is used by the validator to match requests with the OpenAPI + // spec. It does not place restrictions on how the wrapped handler routes + // requests; use of gorilla/mux is just a validator implementation detail. + router, err := gorillamux.NewRouter(doc) + if err != nil { + panic(err) + } + // Strict validation will respond HTTP 500 if the service tries to emit a + // response that does not conform to the OpenAPI spec. Very useful for + // testing a service against its spec in development and CI. In production, + // availability may be more important than strictness. + v := openapi3filter.NewValidator(router, openapi3filter.Strict(true)) + // Now we can finally set the main server handler. + mainHandler = v.Middleware(squareHandler) + + printResp := func(resp *http.Response, err error) { + defer resp.Body.Close() + if err != nil { + panic(err) + } + contents, err := ioutil.ReadAll(resp.Body) + if err != nil { + panic(err) + } + fmt.Println(resp.StatusCode, strings.TrimSpace(string(contents))) + } + // Valid requests to our sum service + printResp(srv.Client().Get(srv.URL + "/square/2")) + printResp(srv.Client().Get(srv.URL + "/square/789")) + // Not found requests - method or path not found + printResp(srv.Client().Post(srv.URL+"/square/2", "application/json", bytes.NewBufferString(`{"result": 5}`))) + printResp(srv.Client().Get(srv.URL + "/sum/2")) + // Bad requests - note they never reach the wrapped square handler (which would panic) + printResp(srv.Client().Get(srv.URL + "/square/five")) + // Invalid responses + printResp(srv.Client().Get(srv.URL + "/square/42")) // Our "easter egg" added a property which is not allowed + printResp(srv.Client().Get(srv.URL + "/square/65536")) // Answer overflows the maximum allowed value (1000000) + // Output: + // 200 {"result":4} + // 200 {"result":622521} + // 404 not found + // 404 not found + // 400 bad request + // 500 server error + // 500 server error +}