Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop sl.LoadSwaggerFromURIFunc #317

Merged
merged 6 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,8 @@ func arrayUniqueItemsChecker(items []interface{}) bool {
// Check the uniqueness of the input slice(array in JSON)
}
```

## Sub-v0 breaking API changes

### v0.47.0
Field `(*openapi3.SwaggerLoader).LoadSwaggerFromURIFunc` of type `func(*openapi3.SwaggerLoader, *url.URL) (*openapi3.Swagger, error)` was removed after the addition of the field `(*openapi3.SwaggerLoader).ReadFromURIFunc` of type `func(*openapi3.SwaggerLoader, *url.URL) ([]byte, error)`.
121 changes: 53 additions & 68 deletions openapi3/swagger_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,14 @@ type SwaggerLoader struct {
// IsExternalRefsAllowed enables visiting other files
IsExternalRefsAllowed bool

// LoadSwaggerFromURIFunc allows overriding the swagger file/URL reading func
LoadSwaggerFromURIFunc func(loader *SwaggerLoader, url *url.URL) (*Swagger, error)

// ReadFromURIFunc allows overriding the any file/URL reading func
ReadFromURIFunc func(loader *SwaggerLoader, url *url.URL) ([]byte, error)

Context context.Context

visitedFiles map[string]struct{}
visitedSwaggers map[string]*Swagger
visitedPathItemRefs map[string]struct{}

visitedDocuments map[string]*Swagger

visitedExample map[*Example]struct{}
visitedHeader map[*Header]struct{}
Expand All @@ -55,22 +53,17 @@ func NewSwaggerLoader() *SwaggerLoader {
return &SwaggerLoader{}
}

func (swaggerLoader *SwaggerLoader) reset() {
swaggerLoader.visitedFiles = make(map[string]struct{})
swaggerLoader.visitedSwaggers = make(map[string]*Swagger)
func (swaggerLoader *SwaggerLoader) resetVisitedPathItemRefs() {
swaggerLoader.visitedPathItemRefs = make(map[string]struct{})
}

// LoadSwaggerFromURI loads a spec from a remote URL
func (swaggerLoader *SwaggerLoader) LoadSwaggerFromURI(location *url.URL) (*Swagger, error) {
swaggerLoader.reset()
swaggerLoader.resetVisitedPathItemRefs()
return swaggerLoader.loadSwaggerFromURIInternal(location)
}

func (swaggerLoader *SwaggerLoader) loadSwaggerFromURIInternal(location *url.URL) (*Swagger, error) {
f := swaggerLoader.LoadSwaggerFromURIFunc
if f != nil {
return f(swaggerLoader, location)
}
data, err := swaggerLoader.readURL(location)
if err != nil {
return nil, err
Expand Down Expand Up @@ -111,8 +104,7 @@ func (swaggerLoader *SwaggerLoader) loadSingleElementFromURI(ref string, rootPat
}

func (swaggerLoader *SwaggerLoader) readURL(location *url.URL) ([]byte, error) {
f := swaggerLoader.ReadFromURIFunc
if f != nil {
if f := swaggerLoader.ReadFromURIFunc; f != nil {
return f(swaggerLoader, location)
}

Expand All @@ -121,36 +113,23 @@ func (swaggerLoader *SwaggerLoader) readURL(location *url.URL) ([]byte, error) {
if err != nil {
return nil, err
}
data, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
if err != nil {
return nil, err
}
return data, nil
return ioutil.ReadAll(resp.Body)
}
if location.Scheme != "" || location.Host != "" || location.RawQuery != "" {
return nil, fmt.Errorf("unsupported URI: %q", location.String())
}
data, err := ioutil.ReadFile(location.Path)
if err != nil {
return nil, err
}
return data, nil
return ioutil.ReadFile(location.Path)
}

// LoadSwaggerFromFile loads a spec from a local file path
func (swaggerLoader *SwaggerLoader) LoadSwaggerFromFile(path string) (*Swagger, error) {
swaggerLoader.reset()
swaggerLoader.resetVisitedPathItemRefs()
return swaggerLoader.loadSwaggerFromFileInternal(path)
}

func (swaggerLoader *SwaggerLoader) loadSwaggerFromFileInternal(path string) (*Swagger, error) {
f := swaggerLoader.LoadSwaggerFromURIFunc
pathAsURL := &url.URL{Path: path}
if f != nil {
x, err := f(swaggerLoader, pathAsURL)
return x, err
}
data, err := swaggerLoader.readURL(pathAsURL)
if err != nil {
return nil, err
Expand All @@ -160,33 +139,39 @@ func (swaggerLoader *SwaggerLoader) loadSwaggerFromFileInternal(path string) (*S

// LoadSwaggerFromData loads a spec from a byte array
func (swaggerLoader *SwaggerLoader) LoadSwaggerFromData(data []byte) (*Swagger, error) {
swaggerLoader.reset()
swaggerLoader.resetVisitedPathItemRefs()
return swaggerLoader.loadSwaggerFromDataInternal(data)
}

func (swaggerLoader *SwaggerLoader) loadSwaggerFromDataInternal(data []byte) (*Swagger, error) {
swagger := &Swagger{}
if err := yaml.Unmarshal(data, swagger); err != nil {
doc := &Swagger{}
if err := yaml.Unmarshal(data, doc); err != nil {
return nil, err
}
if err := swaggerLoader.ResolveRefsIn(doc, nil); err != nil {
return nil, err
}
return swagger, swaggerLoader.ResolveRefsIn(swagger, nil)
return doc, nil
}

// LoadSwaggerFromDataWithPath takes the OpenApi spec data in bytes and a path where the resolver can find referred
// elements and returns a *Swagger with all resolved data or an error if unable to load data or resolve refs.
func (swaggerLoader *SwaggerLoader) LoadSwaggerFromDataWithPath(data []byte, path *url.URL) (*Swagger, error) {
swaggerLoader.reset()
swaggerLoader.resetVisitedPathItemRefs()
return swaggerLoader.loadSwaggerFromDataWithPathInternal(data, path)
}

func (swaggerLoader *SwaggerLoader) loadSwaggerFromDataWithPathInternal(data []byte, path *url.URL) (*Swagger, error) {
visited, ok := swaggerLoader.visitedSwaggers[path.String()]
if ok {
return visited, nil
if swaggerLoader.visitedDocuments == nil {
swaggerLoader.visitedDocuments = make(map[string]*Swagger)
}
uri := path.String()
if doc, ok := swaggerLoader.visitedDocuments[uri]; ok {
return doc, nil
}

swagger := &Swagger{}
swaggerLoader.visitedSwaggers[path.String()] = swagger
swaggerLoader.visitedDocuments[uri] = swagger

if err := yaml.Unmarshal(data, swagger); err != nil {
return nil, err
Expand All @@ -200,32 +185,8 @@ func (swaggerLoader *SwaggerLoader) loadSwaggerFromDataWithPathInternal(data []b

// ResolveRefsIn expands references if for instance spec was just unmarshalled
func (swaggerLoader *SwaggerLoader) ResolveRefsIn(swagger *Swagger, path *url.URL) (err error) {
if swaggerLoader.visitedExample == nil {
swaggerLoader.visitedExample = make(map[*Example]struct{})
}
if swaggerLoader.visitedHeader == nil {
swaggerLoader.visitedHeader = make(map[*Header]struct{})
}
if swaggerLoader.visitedLink == nil {
swaggerLoader.visitedLink = make(map[*Link]struct{})
}
if swaggerLoader.visitedParameter == nil {
swaggerLoader.visitedParameter = make(map[*Parameter]struct{})
}
if swaggerLoader.visitedRequestBody == nil {
swaggerLoader.visitedRequestBody = make(map[*RequestBody]struct{})
}
if swaggerLoader.visitedResponse == nil {
swaggerLoader.visitedResponse = make(map[*Response]struct{})
}
if swaggerLoader.visitedSchema == nil {
swaggerLoader.visitedSchema = make(map[*Schema]struct{})
}
if swaggerLoader.visitedSecurityScheme == nil {
swaggerLoader.visitedSecurityScheme = make(map[*SecurityScheme]struct{})
}
if swaggerLoader.visitedFiles == nil {
swaggerLoader.reset()
if swaggerLoader.visitedPathItemRefs == nil {
swaggerLoader.resetVisitedPathItemRefs()
}

// Visit all components
Expand Down Expand Up @@ -456,6 +417,9 @@ func (swaggerLoader *SwaggerLoader) resolveRefSwagger(swagger *Swagger, ref stri

func (swaggerLoader *SwaggerLoader) resolveHeaderRef(swagger *Swagger, component *HeaderRef, documentPath *url.URL) error {
if component != nil && component.Value != nil {
if swaggerLoader.visitedHeader == nil {
swaggerLoader.visitedHeader = make(map[*Header]struct{})
}
if _, ok := swaggerLoader.visitedHeader[component.Value]; ok {
return nil
}
Expand Down Expand Up @@ -500,6 +464,9 @@ func (swaggerLoader *SwaggerLoader) resolveHeaderRef(swagger *Swagger, component

func (swaggerLoader *SwaggerLoader) resolveParameterRef(swagger *Swagger, component *ParameterRef, documentPath *url.URL) error {
if component != nil && component.Value != nil {
if swaggerLoader.visitedParameter == nil {
swaggerLoader.visitedParameter = make(map[*Parameter]struct{})
}
if _, ok := swaggerLoader.visitedParameter[component.Value]; ok {
return nil
}
Expand Down Expand Up @@ -560,6 +527,9 @@ func (swaggerLoader *SwaggerLoader) resolveParameterRef(swagger *Swagger, compon

func (swaggerLoader *SwaggerLoader) resolveRequestBodyRef(swagger *Swagger, component *RequestBodyRef, documentPath *url.URL) error {
if component != nil && component.Value != nil {
if swaggerLoader.visitedRequestBody == nil {
swaggerLoader.visitedRequestBody = make(map[*RequestBody]struct{})
}
if _, ok := swaggerLoader.visitedRequestBody[component.Value]; ok {
return nil
}
Expand Down Expand Up @@ -612,6 +582,9 @@ func (swaggerLoader *SwaggerLoader) resolveRequestBodyRef(swagger *Swagger, comp

func (swaggerLoader *SwaggerLoader) resolveResponseRef(swagger *Swagger, component *ResponseRef, documentPath *url.URL) error {
if component != nil && component.Value != nil {
if swaggerLoader.visitedResponse == nil {
swaggerLoader.visitedResponse = make(map[*Response]struct{})
}
if _, ok := swaggerLoader.visitedResponse[component.Value]; ok {
return nil
}
Expand Down Expand Up @@ -683,6 +656,9 @@ func (swaggerLoader *SwaggerLoader) resolveResponseRef(swagger *Swagger, compone

func (swaggerLoader *SwaggerLoader) resolveSchemaRef(swagger *Swagger, component *SchemaRef, documentPath *url.URL) error {
if component != nil && component.Value != nil {
if swaggerLoader.visitedSchema == nil {
swaggerLoader.visitedSchema = make(map[*Schema]struct{})
}
if _, ok := swaggerLoader.visitedSchema[component.Value]; ok {
return nil
}
Expand Down Expand Up @@ -766,6 +742,9 @@ func (swaggerLoader *SwaggerLoader) resolveSchemaRef(swagger *Swagger, component

func (swaggerLoader *SwaggerLoader) resolveSecuritySchemeRef(swagger *Swagger, component *SecuritySchemeRef, documentPath *url.URL) error {
if component != nil && component.Value != nil {
if swaggerLoader.visitedSecurityScheme == nil {
swaggerLoader.visitedSecurityScheme = make(map[*SecurityScheme]struct{})
}
if _, ok := swaggerLoader.visitedSecurityScheme[component.Value]; ok {
return nil
}
Expand Down Expand Up @@ -801,6 +780,9 @@ func (swaggerLoader *SwaggerLoader) resolveSecuritySchemeRef(swagger *Swagger, c

func (swaggerLoader *SwaggerLoader) resolveExampleRef(swagger *Swagger, component *ExampleRef, documentPath *url.URL) error {
if component != nil && component.Value != nil {
if swaggerLoader.visitedExample == nil {
swaggerLoader.visitedExample = make(map[*Example]struct{})
}
if _, ok := swaggerLoader.visitedExample[component.Value]; ok {
return nil
}
Expand Down Expand Up @@ -836,6 +818,9 @@ func (swaggerLoader *SwaggerLoader) resolveExampleRef(swagger *Swagger, componen

func (swaggerLoader *SwaggerLoader) resolveLinkRef(swagger *Swagger, component *LinkRef, documentPath *url.URL) error {
if component != nil && component.Value != nil {
if swaggerLoader.visitedLink == nil {
swaggerLoader.visitedLink = make(map[*Link]struct{})
}
if _, ok := swaggerLoader.visitedLink[component.Value]; ok {
return nil
}
Expand Down Expand Up @@ -875,10 +860,10 @@ func (swaggerLoader *SwaggerLoader) resolvePathItemRef(swagger *Swagger, entrypo
key = documentPath.EscapedPath()
}
key += entrypoint
if _, ok := swaggerLoader.visitedFiles[key]; ok {
if _, ok := swaggerLoader.visitedPathItemRefs[key]; ok {
return nil
}
swaggerLoader.visitedFiles[key] = struct{}{}
swaggerLoader.visitedPathItemRefs[key] = struct{}{}

const prefix = "#/paths/"
if pathItem == nil {
Expand Down
50 changes: 50 additions & 0 deletions openapi3/swagger_loader_read_from_uri_func_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openapi3

import (
"fmt"
"io/ioutil"
"net/url"
"path/filepath"
Expand All @@ -21,3 +22,52 @@ func TestLoaderReadFromURIFunc(t *testing.T) {
require.NoError(t, doc.Validate(loader.Context))
require.Equal(t, "bar", doc.Paths["/foo"].Get.Responses.Get(200).Value.Content.Get("application/json").Schema.Value.Properties["foo"].Value.Properties["bar"].Value.Items.Value.Example)
}

type multipleSourceSwaggerLoaderExample struct {
Sources map[string][]byte
}

func (l *multipleSourceSwaggerLoaderExample) LoadSwaggerFromURI(
loader *SwaggerLoader,
location *url.URL,
) ([]byte, error) {
source := l.resolveSourceFromURI(location)
if source == nil {
return nil, fmt.Errorf("Unsupported URI: %q", location.String())
}
return source, nil
}

func (l *multipleSourceSwaggerLoaderExample) resolveSourceFromURI(location fmt.Stringer) []byte {
return l.Sources[location.String()]
}

func TestResolveSchemaExternalRef(t *testing.T) {
rootLocation := &url.URL{Scheme: "http", Host: "example.com", Path: "spec.json"}
externalLocation := &url.URL{Scheme: "http", Host: "example.com", Path: "external.json"}
rootSpec := []byte(fmt.Sprintf(
`{"openapi":"3.0.0","info":{"title":"MyAPI","version":"0.1","description":"An API"},"paths":{},"components":{"schemas":{"Root":{"allOf":[{"$ref":"%s#/components/schemas/External"}]}}}}`,
externalLocation.String(),
))
externalSpec := []byte(`{"openapi":"3.0.0","info":{"title":"MyAPI","version":"0.1","description":"External Spec"},"paths":{},"components":{"schemas":{"External":{"type":"string"}}}}`)
multipleSourceLoader := &multipleSourceSwaggerLoaderExample{
Sources: map[string][]byte{
rootLocation.String(): rootSpec,
externalLocation.String(): externalSpec,
},
}
loader := &SwaggerLoader{
IsExternalRefsAllowed: true,
ReadFromURIFunc: multipleSourceLoader.LoadSwaggerFromURI,
}

doc, err := loader.LoadSwaggerFromURI(rootLocation)
require.NoError(t, err)

err = doc.Validate(loader.Context)
require.NoError(t, err)

refRootVisited := doc.Components.Schemas["Root"].Value.AllOf[0]
require.Equal(t, fmt.Sprintf("%s#/components/schemas/External", externalLocation.String()), refRootVisited.Ref)
require.NotNil(t, refRootVisited.Value)
}
Loading