Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenAPI relative server URLs #230

Merged
merged 4 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions handler/transport/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
}

if b.openAPIValidator != nil {
if err = b.openAPIValidator.ValidateRequest(req); err != nil {
return nil, errors.BackendValidation.Label(b.name).With(err)
// FIXME tc.Origin should be an origin, not just a host!
if err = b.openAPIValidator.ValidateRequest(req, tc.hash(), tc.Scheme+"://"+tc.Origin); err != nil {
return nil, errors.BackendValidation.Label(b.name).Kind("backend_request_validation").With(err)
}
}

Expand Down Expand Up @@ -142,7 +143,8 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {

if b.openAPIValidator != nil {
if err = b.openAPIValidator.ValidateResponse(beresp); err != nil {
return nil, errors.BackendValidation.Label(b.name).With(err).Status(http.StatusBadGateway)
return nil, errors.BackendValidation.Label(b.name).Kind("backend_response_validation").
With(err).Status(http.StatusBadGateway)
}
}

Expand Down
20 changes: 11 additions & 9 deletions handler/transport/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,35 +197,35 @@ func TestBackend_RoundTrip_Validation(t *testing.T) {
http.MethodPost,
"/get",
"backend validation error",
"request validation: 'POST /get': Path doesn't support the HTTP method",
"'POST /get': Path doesn't support the HTTP method",
},
{
"invalid request, IgnoreRequestViolations",
&config.OpenAPI{File: "testdata/upstream.yaml", IgnoreRequestViolations: true, IgnoreResponseViolations: true},
http.MethodPost,
"/get",
"",
"request validation: 'POST /get': Path doesn't support the HTTP method",
"'POST /get': Path doesn't support the HTTP method",
},
{
"invalid response",
&config.OpenAPI{File: "testdata/upstream.yaml"},
http.MethodGet,
"/get?404",
"backend validation error",
"response validation: status is not supported",
"status is not supported",
},
{
"invalid response, IgnoreResponseViolations",
&config.OpenAPI{File: "testdata/upstream.yaml", IgnoreResponseViolations: true},
http.MethodGet,
"/get?404",
"",
"response validation: status is not supported",
"status is not supported",
},
}

logger, hook := logrustest.NewNullLogger()
logger, hook := test.NewLogger()
log := logger.WithContext(context.Background())

for _, tt := range tests {
Expand Down Expand Up @@ -260,13 +260,15 @@ func TestBackend_RoundTrip_Validation(t *testing.T) {
entry := hook.LastEntry()
if tt.expectedLogMessage != "" {
if data, ok := entry.Data["validation"]; ok {
for _, err := range data.([]string) {
if err == tt.expectedLogMessage {
for _, errStr := range data.([]string) {
if errStr != tt.expectedLogMessage {
subT.Errorf("\nwant:\t%s\ngot:\t%v", tt.expectedLogMessage, errStr)
return
}
return
}
for _, err := range data.([]string) {
subT.Log(err)
for _, errStr := range data.([]string) {
subT.Log(errStr)
}
}
subT.Errorf("expected matching validation error logs:\n\t%s\n\tgot: nothing", tt.expectedLogMessage)
Expand Down
75 changes: 64 additions & 11 deletions handler/validation/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"sync"

"github.com/avenga/couper/config/request"

"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"

"github.com/avenga/couper/config/request"
"github.com/avenga/couper/eval"
)

var swaggers sync.Map

type OpenAPI struct {
options *OpenAPIOptions
requestValidationInput *openapi3filter.RequestValidationInput
Expand All @@ -28,10 +32,63 @@ func NewOpenAPI(opts *OpenAPIOptions) *OpenAPI {
}
}

func (v *OpenAPI) ValidateRequest(req *http.Request) error {
route, pathParams, err := v.options.router.FindRoute(req.Method, req.URL)
func (v *OpenAPI) getModifiedSwagger(key, origin string) (*openapi3.Swagger, error) {
swagger, exists := swaggers.Load(key)
if !exists {
clonedSwagger := cloneSwagger(v.options.swagger)

var newServers []string
for _, s := range clonedSwagger.Servers {
malud marked this conversation as resolved.
Show resolved Hide resolved
su, err := url.Parse(s.URL)
if err != nil {
return nil, err
}
if !su.IsAbs() {
newServers = append(newServers, origin+s.URL)
}
}
for _, ns := range newServers {
clonedSwagger.AddServer(&openapi3.Server{URL: ns})
}

swaggers.Store(key, clonedSwagger)
swagger = clonedSwagger
}

if s, ok := swagger.(*openapi3.Swagger); ok {
return s, nil
}

return nil, fmt.Errorf("swagger wrong type: %v", swagger)
}

func cloneSwagger(s *openapi3.Swagger) *openapi3.Swagger {
sw := *s
// this is not a deep clone; we only want to add servers
sw.Servers = s.Servers[:]
return &sw
}

func (v *OpenAPI) ValidateRequest(req *http.Request, key, origin string) error {
swagger, err := v.getModifiedSwagger(key, origin)
if err != nil {
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
if !v.options.ignoreRequestViolations {
return err
}
return nil
}

router := openapi3filter.NewRouter()
if err = router.AddSwagger(swagger); err != nil {
return err
}

route, pathParams, err := router.FindRoute(req.Method, req.URL)
if err != nil {
err = fmt.Errorf("request validation: '%s %s': %w", req.Method, req.URL.Path, err)
err = fmt.Errorf("'%s %s': %w", req.Method, req.URL.Path, err)
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
Expand All @@ -50,10 +107,7 @@ func (v *OpenAPI) ValidateRequest(req *http.Request) error {
}

// openapi3filter.ValidateRequestBody also handles resetting the req body after reading until EOF.
err = openapi3filter.ValidateRequest(req.Context(), v.requestValidationInput)

if err != nil {
err = fmt.Errorf("request validation: %w", err)
if err = openapi3filter.ValidateRequest(req.Context(), v.requestValidationInput); err != nil {
if ctx, ok := req.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
}
Expand All @@ -68,7 +122,7 @@ func (v *OpenAPI) ValidateRequest(req *http.Request) error {
func (v *OpenAPI) ValidateResponse(beresp *http.Response) error {
// since a request validation could fail and ignored due to user options, the input route MAY be nil
if v.requestValidationInput == nil || v.requestValidationInput.Route == nil {
err := fmt.Errorf("response validation: '%s %s': invalid route", beresp.Request.Method, beresp.Request.URL.Path)
err := fmt.Errorf("'%s %s': invalid route", beresp.Request.Method, beresp.Request.URL.Path)
if beresp.Request != nil {
if ctx, ok := beresp.Request.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
Expand Down Expand Up @@ -103,7 +157,6 @@ func (v *OpenAPI) ValidateResponse(beresp *http.Response) error {
}

if err := openapi3filter.ValidateResponse(beresp.Request.Context(), responseValidationInput); err != nil {
err = fmt.Errorf("response validation: %w", err)
if beresp.Request != nil {
if ctx, ok := beresp.Request.Context().Value(request.OpenAPI).(*OpenAPIContext); ok {
ctx.errors = append(ctx.errors, err)
Expand Down
28 changes: 24 additions & 4 deletions handler/validation/openapi_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package validation
import (
"fmt"
"io/ioutil"
"net/url"
"path/filepath"

"github.com/getkin/kin-openapi/openapi3"
Expand All @@ -17,7 +18,7 @@ type OpenAPIOptions struct {
ignoreRequestViolations bool
ignoreResponseViolations bool
filterOptions *openapi3filter.Options
router *openapi3filter.Router
swagger *openapi3.Swagger
}

// NewOpenAPIOptions takes a list of openAPI configuration due to merging configurations.
Expand All @@ -40,6 +41,26 @@ func NewOpenAPIOptions(openapi *config.OpenAPI) (*OpenAPIOptions, error) {
return NewOpenAPIOptionsFromBytes(openapi, b)
}

func canonicalizeServerURLs(swagger *openapi3.Swagger) error {
for _, server := range swagger.Servers {
su, err := url.Parse(server.URL)
if err != nil {
return err
}

if su.IsAbs() && su.Port() == "" && (su.Scheme == "https" || su.Scheme == "http") {
su.Host = su.Hostname() + ":"
if su.Scheme == "https" {
su.Host += "443"
} else {
su.Host += "80"
}
server.URL = su.String()
}
}
return nil
}

func NewOpenAPIOptionsFromBytes(openapi *config.OpenAPI, bytes []byte) (*OpenAPIOptions, error) {
if openapi == nil || bytes == nil {
return nil, nil
Expand All @@ -50,8 +71,7 @@ func NewOpenAPIOptionsFromBytes(openapi *config.OpenAPI, bytes []byte) (*OpenAPI
return nil, fmt.Errorf("error loading openapi file: %w", err)
}

router := openapi3filter.NewRouter()
if err = router.AddSwagger(swagger); err != nil {
if err = canonicalizeServerURLs(swagger); err != nil {
return nil, err
}

Expand All @@ -68,6 +88,6 @@ func NewOpenAPIOptionsFromBytes(openapi *config.OpenAPI, bytes []byte) (*OpenAPI
},
ignoreRequestViolations: openapi.IgnoreRequestViolations,
ignoreResponseViolations: openapi.IgnoreResponseViolations,
router: router,
swagger: swagger,
}, nil
}
Loading