From 3d0d8d92eefcfb372e3dc9c173c7271cad9453e9 Mon Sep 17 00:00:00 2001 From: Sumihiko Natsu Date: Wed, 3 Mar 2021 09:43:13 +0900 Subject: [PATCH] Bypass any file/URL reading by ReadFromURIFunc --- openapi3/swagger_loader.go | 18 +++++++++++---- .../swagger_loader_read_from_uri_func_test.go | 23 +++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) create mode 100644 openapi3/swagger_loader_read_from_uri_func_test.go diff --git a/openapi3/swagger_loader.go b/openapi3/swagger_loader.go index 6755b2917..db0eb44be 100644 --- a/openapi3/swagger_loader.go +++ b/openapi3/swagger_loader.go @@ -29,9 +29,12 @@ type SwaggerLoader struct { // IsExternalRefsAllowed enables visiting other files IsExternalRefsAllowed bool - // LoadSwaggerFromURIFunc allows overriding the file/URL reading func + // LoadSwaggerFromURIFunc allows overriding the swagger file/URL reading func LoadSwaggerFromURIFunc func(loader *SwaggerLoader, url *url.URL) (*Swagger, error) + // ReadFromURIFunc allows overriding the any file/URL reading func + ReadFromURIFunc func(loader *SwaggerLoader, url *url.URL) ([]byte, error) + Context context.Context visitedFiles map[string]struct{} @@ -68,7 +71,7 @@ func (swaggerLoader *SwaggerLoader) loadSwaggerFromURIInternal(location *url.URL if f != nil { return f(swaggerLoader, location) } - data, err := readURL(location) + data, err := swaggerLoader.readURL(location) if err != nil { return nil, err } @@ -96,7 +99,7 @@ func (swaggerLoader *SwaggerLoader) loadSingleElementFromURI(ref string, rootPat return fmt.Errorf("could not resolve path: %v", err) } - data, err := readURL(resolvedPath) + data, err := swaggerLoader.readURL(resolvedPath) if err != nil { return err } @@ -107,7 +110,12 @@ func (swaggerLoader *SwaggerLoader) loadSingleElementFromURI(ref string, rootPat return nil } -func readURL(location *url.URL) ([]byte, error) { +func (swaggerLoader *SwaggerLoader) readURL(location *url.URL) ([]byte, error) { + f := swaggerLoader.ReadFromURIFunc + if f != nil { + return f(swaggerLoader, location) + } + if location.Scheme != "" && location.Host != "" { resp, err := http.Get(location.String()) if err != nil { @@ -143,7 +151,7 @@ func (swaggerLoader *SwaggerLoader) loadSwaggerFromFileInternal(path string) (*S x, err := f(swaggerLoader, pathAsURL) return x, err } - data, err := ioutil.ReadFile(path) + data, err := swaggerLoader.readURL(pathAsURL) if err != nil { return nil, err } diff --git a/openapi3/swagger_loader_read_from_uri_func_test.go b/openapi3/swagger_loader_read_from_uri_func_test.go new file mode 100644 index 000000000..3ac0ed74f --- /dev/null +++ b/openapi3/swagger_loader_read_from_uri_func_test.go @@ -0,0 +1,23 @@ +package openapi3 + +import ( + "io/ioutil" + "net/url" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoaderReadFromURIFunc(t *testing.T) { + loader := NewSwaggerLoader() + loader.IsExternalRefsAllowed = true + loader.ReadFromURIFunc = func(loader *SwaggerLoader, url *url.URL) ([]byte, error) { + return ioutil.ReadFile(filepath.Join("testdata", url.Path)) + } + doc, err := loader.LoadSwaggerFromFile("recursiveRef/openapi.yml") + require.NoError(t, err) + require.NotNil(t, doc) + require.NoError(t, doc.Validate(loader.Context)) + require.Equal(t, "bar", doc.Paths["/foo"].Get.Responses.Get(200).Value.Content.Get("application/json").Schema.Value.Properties["foo"].Value.Properties["bar"].Value.Items.Value.Example) +}