Skip to content

Commit

Permalink
Bypass any file/URL reading by ReadFromURIFunc (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
hottestseason authored Mar 3, 2021
1 parent 49752fc commit ea43ca7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
18 changes: 13 additions & 5 deletions openapi3/swagger_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ type SwaggerLoader struct {
// IsExternalRefsAllowed enables visiting other files
IsExternalRefsAllowed bool

// LoadSwaggerFromURIFunc allows overriding the file/URL reading func
// LoadSwaggerFromURIFunc allows overriding the swagger file/URL reading func
LoadSwaggerFromURIFunc func(loader *SwaggerLoader, url *url.URL) (*Swagger, error)

// ReadFromURIFunc allows overriding the any file/URL reading func
ReadFromURIFunc func(loader *SwaggerLoader, url *url.URL) ([]byte, error)

Context context.Context

visitedFiles map[string]struct{}
Expand Down Expand Up @@ -68,7 +71,7 @@ func (swaggerLoader *SwaggerLoader) loadSwaggerFromURIInternal(location *url.URL
if f != nil {
return f(swaggerLoader, location)
}
data, err := readURL(location)
data, err := swaggerLoader.readURL(location)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -96,7 +99,7 @@ func (swaggerLoader *SwaggerLoader) loadSingleElementFromURI(ref string, rootPat
return fmt.Errorf("could not resolve path: %v", err)
}

data, err := readURL(resolvedPath)
data, err := swaggerLoader.readURL(resolvedPath)
if err != nil {
return err
}
Expand All @@ -107,7 +110,12 @@ func (swaggerLoader *SwaggerLoader) loadSingleElementFromURI(ref string, rootPat
return nil
}

func readURL(location *url.URL) ([]byte, error) {
func (swaggerLoader *SwaggerLoader) readURL(location *url.URL) ([]byte, error) {
f := swaggerLoader.ReadFromURIFunc
if f != nil {
return f(swaggerLoader, location)
}

if location.Scheme != "" && location.Host != "" {
resp, err := http.Get(location.String())
if err != nil {
Expand Down Expand Up @@ -143,7 +151,7 @@ func (swaggerLoader *SwaggerLoader) loadSwaggerFromFileInternal(path string) (*S
x, err := f(swaggerLoader, pathAsURL)
return x, err
}
data, err := ioutil.ReadFile(path)
data, err := swaggerLoader.readURL(pathAsURL)
if err != nil {
return nil, err
}
Expand Down
23 changes: 23 additions & 0 deletions openapi3/swagger_loader_read_from_uri_func_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package openapi3

import (
"io/ioutil"
"net/url"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

func TestLoaderReadFromURIFunc(t *testing.T) {
loader := NewSwaggerLoader()
loader.IsExternalRefsAllowed = true
loader.ReadFromURIFunc = func(loader *SwaggerLoader, url *url.URL) ([]byte, error) {
return ioutil.ReadFile(filepath.Join("testdata", url.Path))
}
doc, err := loader.LoadSwaggerFromFile("recursiveRef/openapi.yml")
require.NoError(t, err)
require.NotNil(t, doc)
require.NoError(t, doc.Validate(loader.Context))
require.Equal(t, "bar", doc.Paths["/foo"].Get.Responses.Get(200).Value.Content.Get("application/json").Schema.Value.Properties["foo"].Value.Properties["bar"].Value.Items.Value.Example)
}

0 comments on commit ea43ca7

Please sign in to comment.