diff --git a/handler/openapi_validator.go b/handler/openapi_validator.go index e635a875a..41248b8c5 100644 --- a/handler/openapi_validator.go +++ b/handler/openapi_validator.go @@ -5,9 +5,12 @@ import ( "io/ioutil" "net/http" "os" + "path/filepath" - "github.com/avenga/couper/config" + "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" + + "github.com/avenga/couper/config" ) type OpenAPIValidatorFactory struct { @@ -24,11 +27,29 @@ func NewOpenAPIValidatorFactory(openapi *config.OpenAPI) (*OpenAPIValidatorFacto if err != nil { return nil, err } - router := openapi3filter.NewRouter() - err = router.AddSwaggerFromFile(dir + "/" + openapi.File) + + bytes, err := ioutil.ReadFile(filepath.Join(dir, openapi.File)) + if err != nil { + return nil, err + } + return NewOpenAPIValidatorFactoryFromBytes(openapi, bytes) +} + +func NewOpenAPIValidatorFactoryFromBytes(openapi *config.OpenAPI, bytes []byte) (*OpenAPIValidatorFactory, error) { + if openapi == nil || bytes == nil { + return nil, nil + } + + swagger, err := openapi3.NewSwaggerLoader().LoadSwaggerFromData(bytes) if err != nil { return nil, err } + + router := openapi3filter.NewRouter() + if err = router.AddSwagger(swagger); err != nil { + return nil, err + } + return &OpenAPIValidatorFactory{ router: router, ignoreRequestViolations: openapi.IgnoreRequestViolations, diff --git a/handler/proxy.go b/handler/proxy.go index c1a515b88..8a2767155 100644 --- a/handler/proxy.go +++ b/handler/proxy.go @@ -11,7 +11,6 @@ import ( "net" "net/http" "net/url" - "os" "regexp" "strconv" "strings" diff --git a/handler/proxy_test.go b/handler/proxy_test.go index 3cc18e819..62cc6bc13 100644 --- a/handler/proxy_test.go +++ b/handler/proxy_test.go @@ -5,10 +5,8 @@ import ( "compress/gzip" "context" "io" - "io/ioutil" "net/http" "net/http/httptest" - "os" "reflect" "strings" "testing" @@ -425,9 +423,7 @@ paths: helper.Must(err) openapiYAML := &bytes.Buffer{} - openapiYAMLTemplate.Execute(openapiYAML, map[string]string{"url": origin.URL}) - helper.Must(ioutil.WriteFile("testdata/upstream.yaml", openapiYAML.Bytes(), 0644)) - defer helper.Must(os.Remove("testdata/upstream.yaml")) + helper.Must(openapiYAMLTemplate.Execute(openapiYAML, map[string]string{"url": origin.URL})) tests := []struct { name string @@ -482,11 +478,14 @@ paths: for _, tt := range tests { t.Run(tt.name, func(subT *testing.T) { logger, hook := logrustest.NewNullLogger() - openapiValidatorFactory, err := handler.NewOpenAPIValidatorFactory(tt.openapi) + openapiValidatorFactory, err := handler.NewOpenAPIValidatorFactoryFromBytes(tt.openapi, openapiYAML.Bytes()) if err != nil { subT.Fatal(err) } - p, err := handler.NewProxy(&handler.ProxyOptions{Origin: origin.URL, OpenAPI: openapiValidatorFactory}, logger.WithContext(context.Background()), nil, eval.NewENVContext(nil)) + content := helper.NewProxyContext(` + origin = "` + origin.URL + `" + `) + p, err := handler.NewProxy(&handler.ProxyOptions{Context: content, OpenAPI: openapiValidatorFactory}, logger.WithContext(context.Background()), nil, eval.NewENVContext(nil)) if err != nil { subT.Fatal(err) }