diff --git a/openapi3/schema.go b/openapi3/schema.go index c1730b6ad..6e7691b8d 100644 --- a/openapi3/schema.go +++ b/openapi3/schema.go @@ -1338,6 +1338,19 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value return schema.expectedType(settings, TypeObject) } + if settings.asreq || settings.asrep { + for propName, propSchema := range schema.Properties { + if value[propName] == nil { + if dlft := propSchema.Value.Default; dlft != nil { + value[propName] = dlft + if f := settings.defaultsSet; f != nil { + settings.onceSettingDefaults.Do(f) + } + } + } + } + } + var me MultiError // "properties" diff --git a/openapi3/schema_validation_settings.go b/openapi3/schema_validation_settings.go index 71db5f237..cb4c142a4 100644 --- a/openapi3/schema_validation_settings.go +++ b/openapi3/schema_validation_settings.go @@ -1,5 +1,9 @@ package openapi3 +import ( + "sync" +) + // SchemaValidationOption describes options a user has when validating request / response bodies. type SchemaValidationOption func(*schemaValidationSettings) @@ -7,6 +11,9 @@ type schemaValidationSettings struct { failfast bool multiError bool asreq, asrep bool // exclusive (XOR) fields + + onceSettingDefaults sync.Once + defaultsSet func() } // FailFast returns schema validation errors quicker. @@ -25,6 +32,11 @@ func VisitAsResponse() SchemaValidationOption { return func(s *schemaValidationSettings) { s.asreq, s.asrep = false, true } } +// DefaultsSet executes the given callback (once) IFF schema validation set default values. +func DefaultsSet(f func()) SchemaValidationOption { + return func(s *schemaValidationSettings) { s.defaultsSet = f } +} + func newSchemaValidationSettings(opts ...SchemaValidationOption) *schemaValidationSettings { settings := &schemaValidationSettings{} for _, opt := range opts { diff --git a/openapi3filter/req_resp_decoder.go b/openapi3filter/req_resp_decoder.go index 12b368384..c41b0a01d 100644 --- a/openapi3filter/req_resp_decoder.go +++ b/openapi3filter/req_resp_decoder.go @@ -814,7 +814,11 @@ const prefixUnsupportedCT = "unsupported content type" // decodeBody returns a decoded body. // The function returns ParseError when a body is invalid. -func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) (interface{}, error) { +func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) ( + string, + interface{}, + error, +) { contentType := header.Get(headerCT) if contentType == "" { if _, ok := body.(*multipart.Part); ok { @@ -824,16 +828,16 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, mediaType := parseMediaType(contentType) decoder, ok := bodyDecoders[mediaType] if !ok { - return nil, &ParseError{ + return "", nil, &ParseError{ Kind: KindUnsupportedFormat, Reason: fmt.Sprintf("%s %q", prefixUnsupportedCT, mediaType), } } value, err := decoder(body, header, schema, encFn) if err != nil { - return nil, err + return "", nil, err } - return value, nil + return mediaType, value, nil } func init() { @@ -982,7 +986,7 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S } var value interface{} - if value, err = decodeBody(part, http.Header(part.Header), valueSchema, subEncFn); err != nil { + if _, value, err = decodeBody(part, http.Header(part.Header), valueSchema, subEncFn); err != nil { if v, ok := err.(*ParseError); ok { return nil, &ParseError{path: []interface{}{name}, Cause: v} } diff --git a/openapi3filter/req_resp_decoder_test.go b/openapi3filter/req_resp_decoder_test.go index 34e63712d..6024ab116 100644 --- a/openapi3filter/req_resp_decoder_test.go +++ b/openapi3filter/req_resp_decoder_test.go @@ -1156,7 +1156,7 @@ func TestDecodeBody(t *testing.T) { } return tc.encoding[name] } - got, err := decodeBody(tc.body, h, schemaRef, encFn) + _, got, err := decodeBody(tc.body, h, schemaRef, encFn) if tc.wantErr != nil { require.Error(t, err) @@ -1226,7 +1226,7 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) { body := strings.NewReader("foo,bar") schema := openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema()).NewRef() encFn := func(string) *openapi3.Encoding { return nil } - got, err := decodeBody(body, h, schema, encFn) + _, got, err := decodeBody(body, h, schema, encFn) require.NoError(t, err) require.Equal(t, []string{"foo", "bar"}, got) @@ -1236,7 +1236,7 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) { originalDecoder = RegisteredBodyDecoder(contentType) require.Nil(t, originalDecoder) - _, err = decodeBody(body, h, schema, encFn) + _, _, err = decodeBody(body, h, schema, encFn) require.Equal(t, &ParseError{ Kind: KindUnsupportedFormat, Reason: prefixUnsupportedCT + ` "text/csv"`, diff --git a/openapi3filter/req_resp_encoder.go b/openapi3filter/req_resp_encoder.go new file mode 100644 index 000000000..b6429d6d8 --- /dev/null +++ b/openapi3filter/req_resp_encoder.go @@ -0,0 +1,27 @@ +package openapi3filter + +import ( + "encoding/json" + "fmt" +) + +func encodeBody(body interface{}, mediaType string) ([]byte, error) { + encoder, ok := bodyEncoders[mediaType] + if !ok { + return nil, &ParseError{ + Kind: KindUnsupportedFormat, + Reason: fmt.Sprintf("%s %q", prefixUnsupportedCT, mediaType), + } + } + return encoder(body) +} + +type bodyEncoder func(body interface{}) ([]byte, error) + +var bodyEncoders = map[string]bodyEncoder{ + "application/json": jsonBodyEncoder, +} + +func jsonBodyEncoder(body interface{}) ([]byte, error) { + return json.Marshal(body) +} diff --git a/openapi3filter/validate_readonly_test.go b/openapi3filter/validate_readonly_test.go index 454a927e9..8b7ccb7ef 100644 --- a/openapi3filter/validate_readonly_test.go +++ b/openapi3filter/validate_readonly_test.go @@ -3,6 +3,7 @@ package openapi3filter import ( "bytes" "encoding/json" + "io/ioutil" "net/http" "testing" @@ -26,6 +27,16 @@ func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { "/accounts": { "post": { "description": "Create a new account", + "parameters": [ + { + "in": "query", + "name": "q", + "schema": { + "type": "string", + "default": "Q" + } + } + ], "requestBody": { "required": true, "content": { @@ -34,6 +45,10 @@ func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { "type": "object", "required": ["_id"], "properties": { + "_": { + "type": "boolean", + "default": false + }, "_id": { "type": "string", "description": "Unique identifier for this object.", @@ -61,10 +76,6 @@ func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { } ` - type Request struct { - ID string `json:"_id"` - } - sl := openapi3.NewLoader() doc, err := sl.LoadFromData([]byte(spec)) require.NoError(t, err) @@ -73,7 +84,12 @@ func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { router, err := legacyrouter.NewRouter(doc) require.NoError(t, err) - b, err := json.Marshal(Request{ID: "bt6kdc3d0cvp6u8u3ft0"}) + b, err := json.Marshal(struct { + Blank bool `json:"_,omitempty"` + ID string `json:"_id"` + }{ + ID: "bt6kdc3d0cvp6u8u3ft0", + }) require.NoError(t, err) httpReq, err := http.NewRequest(http.MethodPost, "/accounts", bytes.NewReader(b)) @@ -89,4 +105,12 @@ func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { Route: route, }) require.NoError(t, err) + + // Unset default values in body were set + validatedReqBody, err := ioutil.ReadAll(httpReq.Body) + require.NoError(t, err) + require.JSONEq(t, `{"_":false,"_id":"bt6kdc3d0cvp6u8u3ft0"}`, string(validatedReqBody)) + // Unset default values in URL were set + // Unset default values in headers were set + // Unset default values in cookies were set } diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index 990b299ef..6985dc6df 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -52,7 +52,6 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { if err = ValidateParameter(ctx, input, parameter); err != nil && !options.MultiError { return err } - if err != nil { me = append(me, err) } @@ -63,7 +62,6 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { if err = ValidateParameter(ctx, input, parameter.Value); err != nil && !options.MultiError { return err } - if err != nil { me = append(me, err) } @@ -75,7 +73,6 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { if err = ValidateRequestBody(ctx, input, requestBody.Value); err != nil && !options.MultiError { return err } - if err != nil { me = append(me, err) } @@ -92,7 +89,6 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) error { if err = ValidateSecurityRequirements(ctx, input, *security); err != nil && !options.MultiError { return err } - if err != nil { me = append(me, err) } @@ -137,6 +133,12 @@ func ValidateParameter(ctx context.Context, input *RequestValidationInput, param } schema = parameter.Schema.Value } + + // // Maybe use default value + // if value == nil && schema != nil { + // value = schema.Default + // } + // Validate a parameter's value. if value == nil { if parameter.Required { @@ -167,16 +169,13 @@ const prefixInvalidCT = "header Content-Type has unexpected value" // The function returns RequestError with ErrInvalidRequired cause when a value is required but not defined. // The function returns RequestError with a openapi3.SchemaError cause when a value is invalid by JSON schema. func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, requestBody *openapi3.RequestBody) error { - var ( - req = input.Request - data []byte - ) - options := input.Options if options == nil { options = DefaultOptions } + var data []byte + req := input.Request if req.Body != http.NoBody && req.Body != nil { defer req.Body.Close() var err error @@ -221,7 +220,7 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req } encFn := func(name string) *openapi3.Encoding { return contentType.Encoding[name] } - value, err := decodeBody(bytes.NewReader(data), req.Header, contentType.Schema, encFn) + mediaType, value, err := decodeBody(bytes.NewReader(data), req.Header, contentType.Schema, encFn) if err != nil { return &RequestError{ Input: input, @@ -231,8 +230,10 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req } } - opts := make([]openapi3.SchemaValidationOption, 0, 2) // 2 potential opts here + defaultsSet := false + opts := make([]openapi3.SchemaValidationOption, 0, 3) // 3 potential opts here opts = append(opts, openapi3.VisitAsRequest()) + opts = append(opts, openapi3.DefaultsSet(func() { defaultsSet = true })) if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } @@ -246,6 +247,21 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req Err: err, } } + + if defaultsSet { + var err error + if data, err = encodeBody(value, mediaType); err != nil { + return &RequestError{ + Input: input, + RequestBody: requestBody, + Reason: "rewriting failed", + Err: err, + } + } + // Put the data back into the input + req.Body = ioutil.NopCloser(bytes.NewReader(data)) + } + return nil } diff --git a/openapi3filter/validate_response.go b/openapi3filter/validate_response.go index 7cb713ace..f19123e53 100644 --- a/openapi3filter/validate_response.go +++ b/openapi3filter/validate_response.go @@ -111,7 +111,7 @@ func ValidateResponse(ctx context.Context, input *ResponseValidationInput) error input.SetBodyBytes(data) encFn := func(name string) *openapi3.Encoding { return contentType.Encoding[name] } - value, err := decodeBody(bytes.NewBuffer(data), input.Header, contentType.Schema, encFn) + _, value, err := decodeBody(bytes.NewBuffer(data), input.Header, contentType.Schema, encFn) if err != nil { return &ResponseError{ Input: input,