Skip to content

Commit

Permalink
feat: handling default in request body and parameter schema (#544)
Browse files Browse the repository at this point in the history
* wip setting defaults for #206

Signed-off-by: Pierre Fenoll <pierrefenoll@gmail.com>

* introduce body encoders

Signed-off-by: Pierre Fenoll <pierrefenoll@gmail.com>

* re-encode only when needed

Signed-off-by: Pierre Fenoll <pierrefenoll@gmail.com>

* set default for parameter and add more test cases

Co-authored-by: Pierre Fenoll <pierrefenoll@gmail.com>
  • Loading branch information
nic-6443 and fenollp authored May 30, 2022
1 parent 121fc06 commit 7f8f768
Show file tree
Hide file tree
Showing 8 changed files with 669 additions and 11 deletions.
13 changes: 13 additions & 0 deletions openapi3/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,19 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value
return schema.expectedType(settings, TypeObject)
}

if settings.asreq || settings.asrep {
for propName, propSchema := range schema.Properties {
if value[propName] == nil {
if dlft := propSchema.Value.Default; dlft != nil {
value[propName] = dlft
if f := settings.defaultsSet; f != nil {
settings.onceSettingDefaults.Do(f)
}
}
}
}
}

var me MultiError

// "properties"
Expand Down
12 changes: 12 additions & 0 deletions openapi3/schema_validation_settings.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package openapi3

import (
"sync"
)

// SchemaValidationOption describes options a user has when validating request / response bodies.
type SchemaValidationOption func(*schemaValidationSettings)

type schemaValidationSettings struct {
failfast bool
multiError bool
asreq, asrep bool // exclusive (XOR) fields

onceSettingDefaults sync.Once
defaultsSet func()
}

// FailFast returns schema validation errors quicker.
Expand All @@ -25,6 +32,11 @@ func VisitAsResponse() SchemaValidationOption {
return func(s *schemaValidationSettings) { s.asreq, s.asrep = false, true }
}

// DefaultsSet executes the given callback (once) IFF schema validation set default values.
func DefaultsSet(f func()) SchemaValidationOption {
return func(s *schemaValidationSettings) { s.defaultsSet = f }
}

func newSchemaValidationSettings(opts ...SchemaValidationOption) *schemaValidationSettings {
settings := &schemaValidationSettings{}
for _, opt := range opts {
Expand Down
14 changes: 9 additions & 5 deletions openapi3filter/req_resp_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,11 @@ const prefixUnsupportedCT = "unsupported content type"

// decodeBody returns a decoded body.
// The function returns ParseError when a body is invalid.
func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) (interface{}, error) {
func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef, encFn EncodingFn) (
string,
interface{},
error,
) {
contentType := header.Get(headerCT)
if contentType == "" {
if _, ok := body.(*multipart.Part); ok {
Expand All @@ -878,16 +882,16 @@ func decodeBody(body io.Reader, header http.Header, schema *openapi3.SchemaRef,
mediaType := parseMediaType(contentType)
decoder, ok := bodyDecoders[mediaType]
if !ok {
return nil, &ParseError{
return "", nil, &ParseError{
Kind: KindUnsupportedFormat,
Reason: fmt.Sprintf("%s %q", prefixUnsupportedCT, mediaType),
}
}
value, err := decoder(body, header, schema, encFn)
if err != nil {
return nil, err
return "", nil, err
}
return value, nil
return mediaType, value, nil
}

func init() {
Expand Down Expand Up @@ -1036,7 +1040,7 @@ func multipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S
}

var value interface{}
if value, err = decodeBody(part, http.Header(part.Header), valueSchema, subEncFn); err != nil {
if _, value, err = decodeBody(part, http.Header(part.Header), valueSchema, subEncFn); err != nil {
if v, ok := err.(*ParseError); ok {
return nil, &ParseError{path: []interface{}{name}, Cause: v}
}
Expand Down
6 changes: 3 additions & 3 deletions openapi3filter/req_resp_decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,7 @@ func TestDecodeBody(t *testing.T) {
}
return tc.encoding[name]
}
got, err := decodeBody(tc.body, h, schemaRef, encFn)
_, got, err := decodeBody(tc.body, h, schemaRef, encFn)

if tc.wantErr != nil {
require.Error(t, err)
Expand Down Expand Up @@ -1350,7 +1350,7 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) {
body := strings.NewReader("foo,bar")
schema := openapi3.NewArraySchema().WithItems(openapi3.NewStringSchema()).NewRef()
encFn := func(string) *openapi3.Encoding { return nil }
got, err := decodeBody(body, h, schema, encFn)
_, got, err := decodeBody(body, h, schema, encFn)

require.NoError(t, err)
require.Equal(t, []string{"foo", "bar"}, got)
Expand All @@ -1360,7 +1360,7 @@ func TestRegisterAndUnregisterBodyDecoder(t *testing.T) {
originalDecoder = RegisteredBodyDecoder(contentType)
require.Nil(t, originalDecoder)

_, err = decodeBody(body, h, schema, encFn)
_, _, err = decodeBody(body, h, schema, encFn)
require.Equal(t, &ParseError{
Kind: KindUnsupportedFormat,
Reason: prefixUnsupportedCT + ` "text/csv"`,
Expand Down
27 changes: 27 additions & 0 deletions openapi3filter/req_resp_encoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package openapi3filter

import (
"encoding/json"
"fmt"
)

func encodeBody(body interface{}, mediaType string) ([]byte, error) {
encoder, ok := bodyEncoders[mediaType]
if !ok {
return nil, &ParseError{
Kind: KindUnsupportedFormat,
Reason: fmt.Sprintf("%s %q", prefixUnsupportedCT, mediaType),
}
}
return encoder(body)
}

type bodyEncoder func(body interface{}) ([]byte, error)

var bodyEncoders = map[string]bodyEncoder{
"application/json": jsonBodyEncoder,
}

func jsonBodyEncoder(body interface{}) ([]byte, error) {
return json.Marshal(body)
}
45 changes: 43 additions & 2 deletions openapi3filter/validate_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,30 @@ func ValidateParameter(ctx context.Context, input *RequestValidationInput, param
}
schema = parameter.Schema.Value
}

// Set default value if needed
if value == nil && schema != nil && schema.Default != nil {
value = schema.Default
req := input.Request
switch parameter.In {
case openapi3.ParameterInPath:
// TODO: no idea how to handle this
case openapi3.ParameterInQuery:
q := req.URL.Query()
q.Add(parameter.Name, fmt.Sprintf("%v", value))
req.URL.RawQuery = q.Encode()
case openapi3.ParameterInHeader:
req.Header.Add(parameter.Name, fmt.Sprintf("%v", value))
case openapi3.ParameterInCookie:
req.AddCookie(&http.Cookie{
Name: parameter.Name,
Value: fmt.Sprintf("%v", value),
})
default:
return fmt.Errorf("unsupported parameter's 'in': %s", parameter.In)
}
}

// Validate a parameter's value and presence.
if parameter.Required && !found {
return &RequestError{Input: input, Parameter: parameter, Reason: ErrInvalidRequired.Error(), Err: ErrInvalidRequired}
Expand Down Expand Up @@ -230,7 +254,7 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req
}

encFn := func(name string) *openapi3.Encoding { return contentType.Encoding[name] }
value, err := decodeBody(bytes.NewReader(data), req.Header, contentType.Schema, encFn)
mediaType, value, err := decodeBody(bytes.NewReader(data), req.Header, contentType.Schema, encFn)
if err != nil {
return &RequestError{
Input: input,
Expand All @@ -240,8 +264,10 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req
}
}

opts := make([]openapi3.SchemaValidationOption, 0, 2) // 2 potential opts here
defaultsSet := false
opts := make([]openapi3.SchemaValidationOption, 0, 3) // 3 potential opts here
opts = append(opts, openapi3.VisitAsRequest())
opts = append(opts, openapi3.DefaultsSet(func() { defaultsSet = true }))
if options.MultiError {
opts = append(opts, openapi3.MultiErrors())
}
Expand All @@ -255,6 +281,21 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req
Err: err,
}
}

if defaultsSet {
var err error
if data, err = encodeBody(value, mediaType); err != nil {
return &RequestError{
Input: input,
RequestBody: requestBody,
Reason: "rewriting failed",
Err: err,
}
}
// Put the data back into the input
req.Body = ioutil.NopCloser(bytes.NewReader(data))
}

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion openapi3filter/validate_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func ValidateResponse(ctx context.Context, input *ResponseValidationInput) error
input.SetBodyBytes(data)

encFn := func(name string) *openapi3.Encoding { return contentType.Encoding[name] }
value, err := decodeBody(bytes.NewBuffer(data), input.Header, contentType.Schema, encFn)
_, value, err := decodeBody(bytes.NewBuffer(data), input.Header, contentType.Schema, encFn)
if err != nil {
return &ResponseError{
Input: input,
Expand Down
Loading

0 comments on commit 7f8f768

Please sign in to comment.