Skip to content

Commit

Permalink
Use and update GetBody() member of request (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShouheiNishi authored Dec 14, 2022
1 parent 8718011 commit f136047
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
23 changes: 21 additions & 2 deletions openapi3filter/validate_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"sort"
Expand Down Expand Up @@ -216,7 +217,19 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req
}
}
// Put the data back into the input
req.Body = ioutil.NopCloser(bytes.NewReader(data))
req.Body = nil
if req.GetBody != nil {
if req.Body, err = req.GetBody(); err != nil {
req.Body = nil
}
}
if req.Body == nil {
req.ContentLength = int64(len(data))
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(data)), nil
}
req.Body, _ = req.GetBody() // no error return
}
}

if len(data) == 0 {
Expand Down Expand Up @@ -292,8 +305,14 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req
}
}
// Put the data back into the input
req.Body = ioutil.NopCloser(bytes.NewReader(data))
if req.Body != nil {
req.Body.Close()
}
req.ContentLength = int64(len(data))
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(data)), nil
}
req.Body, _ = req.GetBody() // no error return
}

return nil
Expand Down
6 changes: 6 additions & 0 deletions openapi3filter/validate_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ components:
assert.Equal(t, contentLen, bodySize, "expect ContentLength %d to equal body size %d", contentLen, bodySize)
bodyModified := originalBodySize != bodySize
assert.Equal(t, bodyModified, tc.expectedModification, "expect request body modification happened: %t, expected %t", bodyModified, tc.expectedModification)

validationInput.Request.Body, err = validationInput.Request.GetBody()
assert.NoError(t, err, "unable to re-generate body by GetBody(): %v", err)
body2, err := io.ReadAll(validationInput.Request.Body)
assert.NoError(t, err, "unable to read request body: %v", err)
assert.Equal(t, body, body2, "body by GetBody() is not matched")
})
}
}

0 comments on commit f136047

Please sign in to comment.