diff --git a/README.md b/README.md index 979c08164..c1761a6ab 100644 --- a/README.md +++ b/README.md @@ -197,3 +197,8 @@ func arrayUniqueItemsChecker(items []interface{}) bool { // Check the uniqueness of the input slice(array in JSON) } ``` + +## Sub-v0 breaking API changes + +### v0.47.0 +Field `(*openapi3.SwaggerLoader).LoadSwaggerFromURIFunc` of type `func(*openapi3.SwaggerLoader, *url.URL) (*openapi3.Swagger, error)` was removed after the addition of the field `(*openapi3.SwaggerLoader).ReadFromURIFunc` of type `func(*openapi3.SwaggerLoader, *url.URL) ([]byte, error)`. diff --git a/openapi3/swagger_loader.go b/openapi3/swagger_loader.go index db0eb44be..3bb902dff 100644 --- a/openapi3/swagger_loader.go +++ b/openapi3/swagger_loader.go @@ -29,16 +29,14 @@ type SwaggerLoader struct { // IsExternalRefsAllowed enables visiting other files IsExternalRefsAllowed bool - // LoadSwaggerFromURIFunc allows overriding the swagger file/URL reading func - LoadSwaggerFromURIFunc func(loader *SwaggerLoader, url *url.URL) (*Swagger, error) - // ReadFromURIFunc allows overriding the any file/URL reading func ReadFromURIFunc func(loader *SwaggerLoader, url *url.URL) ([]byte, error) Context context.Context - visitedFiles map[string]struct{} - visitedSwaggers map[string]*Swagger + visitedPathItemRefs map[string]struct{} + + visitedDocuments map[string]*Swagger visitedExample map[*Example]struct{} visitedHeader map[*Header]struct{} @@ -55,22 +53,17 @@ func NewSwaggerLoader() *SwaggerLoader { return &SwaggerLoader{} } -func (swaggerLoader *SwaggerLoader) reset() { - swaggerLoader.visitedFiles = make(map[string]struct{}) - swaggerLoader.visitedSwaggers = make(map[string]*Swagger) +func (swaggerLoader *SwaggerLoader) resetVisitedPathItemRefs() { + swaggerLoader.visitedPathItemRefs = make(map[string]struct{}) } // LoadSwaggerFromURI loads a spec from a remote URL func (swaggerLoader *SwaggerLoader) LoadSwaggerFromURI(location *url.URL) (*Swagger, error) { - swaggerLoader.reset() + swaggerLoader.resetVisitedPathItemRefs() return swaggerLoader.loadSwaggerFromURIInternal(location) } func (swaggerLoader *SwaggerLoader) loadSwaggerFromURIInternal(location *url.URL) (*Swagger, error) { - f := swaggerLoader.LoadSwaggerFromURIFunc - if f != nil { - return f(swaggerLoader, location) - } data, err := swaggerLoader.readURL(location) if err != nil { return nil, err @@ -111,8 +104,7 @@ func (swaggerLoader *SwaggerLoader) loadSingleElementFromURI(ref string, rootPat } func (swaggerLoader *SwaggerLoader) readURL(location *url.URL) ([]byte, error) { - f := swaggerLoader.ReadFromURIFunc - if f != nil { + if f := swaggerLoader.ReadFromURIFunc; f != nil { return f(swaggerLoader, location) } @@ -121,36 +113,23 @@ func (swaggerLoader *SwaggerLoader) readURL(location *url.URL) ([]byte, error) { if err != nil { return nil, err } - data, err := ioutil.ReadAll(resp.Body) defer resp.Body.Close() - if err != nil { - return nil, err - } - return data, nil + return ioutil.ReadAll(resp.Body) } if location.Scheme != "" || location.Host != "" || location.RawQuery != "" { return nil, fmt.Errorf("unsupported URI: %q", location.String()) } - data, err := ioutil.ReadFile(location.Path) - if err != nil { - return nil, err - } - return data, nil + return ioutil.ReadFile(location.Path) } // LoadSwaggerFromFile loads a spec from a local file path func (swaggerLoader *SwaggerLoader) LoadSwaggerFromFile(path string) (*Swagger, error) { - swaggerLoader.reset() + swaggerLoader.resetVisitedPathItemRefs() return swaggerLoader.loadSwaggerFromFileInternal(path) } func (swaggerLoader *SwaggerLoader) loadSwaggerFromFileInternal(path string) (*Swagger, error) { - f := swaggerLoader.LoadSwaggerFromURIFunc pathAsURL := &url.URL{Path: path} - if f != nil { - x, err := f(swaggerLoader, pathAsURL) - return x, err - } data, err := swaggerLoader.readURL(pathAsURL) if err != nil { return nil, err @@ -160,33 +139,39 @@ func (swaggerLoader *SwaggerLoader) loadSwaggerFromFileInternal(path string) (*S // LoadSwaggerFromData loads a spec from a byte array func (swaggerLoader *SwaggerLoader) LoadSwaggerFromData(data []byte) (*Swagger, error) { - swaggerLoader.reset() + swaggerLoader.resetVisitedPathItemRefs() return swaggerLoader.loadSwaggerFromDataInternal(data) } func (swaggerLoader *SwaggerLoader) loadSwaggerFromDataInternal(data []byte) (*Swagger, error) { - swagger := &Swagger{} - if err := yaml.Unmarshal(data, swagger); err != nil { + doc := &Swagger{} + if err := yaml.Unmarshal(data, doc); err != nil { + return nil, err + } + if err := swaggerLoader.ResolveRefsIn(doc, nil); err != nil { return nil, err } - return swagger, swaggerLoader.ResolveRefsIn(swagger, nil) + return doc, nil } // LoadSwaggerFromDataWithPath takes the OpenApi spec data in bytes and a path where the resolver can find referred // elements and returns a *Swagger with all resolved data or an error if unable to load data or resolve refs. func (swaggerLoader *SwaggerLoader) LoadSwaggerFromDataWithPath(data []byte, path *url.URL) (*Swagger, error) { - swaggerLoader.reset() + swaggerLoader.resetVisitedPathItemRefs() return swaggerLoader.loadSwaggerFromDataWithPathInternal(data, path) } func (swaggerLoader *SwaggerLoader) loadSwaggerFromDataWithPathInternal(data []byte, path *url.URL) (*Swagger, error) { - visited, ok := swaggerLoader.visitedSwaggers[path.String()] - if ok { - return visited, nil + if swaggerLoader.visitedDocuments == nil { + swaggerLoader.visitedDocuments = make(map[string]*Swagger) + } + uri := path.String() + if doc, ok := swaggerLoader.visitedDocuments[uri]; ok { + return doc, nil } swagger := &Swagger{} - swaggerLoader.visitedSwaggers[path.String()] = swagger + swaggerLoader.visitedDocuments[uri] = swagger if err := yaml.Unmarshal(data, swagger); err != nil { return nil, err @@ -200,32 +185,8 @@ func (swaggerLoader *SwaggerLoader) loadSwaggerFromDataWithPathInternal(data []b // ResolveRefsIn expands references if for instance spec was just unmarshalled func (swaggerLoader *SwaggerLoader) ResolveRefsIn(swagger *Swagger, path *url.URL) (err error) { - if swaggerLoader.visitedExample == nil { - swaggerLoader.visitedExample = make(map[*Example]struct{}) - } - if swaggerLoader.visitedHeader == nil { - swaggerLoader.visitedHeader = make(map[*Header]struct{}) - } - if swaggerLoader.visitedLink == nil { - swaggerLoader.visitedLink = make(map[*Link]struct{}) - } - if swaggerLoader.visitedParameter == nil { - swaggerLoader.visitedParameter = make(map[*Parameter]struct{}) - } - if swaggerLoader.visitedRequestBody == nil { - swaggerLoader.visitedRequestBody = make(map[*RequestBody]struct{}) - } - if swaggerLoader.visitedResponse == nil { - swaggerLoader.visitedResponse = make(map[*Response]struct{}) - } - if swaggerLoader.visitedSchema == nil { - swaggerLoader.visitedSchema = make(map[*Schema]struct{}) - } - if swaggerLoader.visitedSecurityScheme == nil { - swaggerLoader.visitedSecurityScheme = make(map[*SecurityScheme]struct{}) - } - if swaggerLoader.visitedFiles == nil { - swaggerLoader.reset() + if swaggerLoader.visitedPathItemRefs == nil { + swaggerLoader.resetVisitedPathItemRefs() } // Visit all components @@ -456,6 +417,9 @@ func (swaggerLoader *SwaggerLoader) resolveRefSwagger(swagger *Swagger, ref stri func (swaggerLoader *SwaggerLoader) resolveHeaderRef(swagger *Swagger, component *HeaderRef, documentPath *url.URL) error { if component != nil && component.Value != nil { + if swaggerLoader.visitedHeader == nil { + swaggerLoader.visitedHeader = make(map[*Header]struct{}) + } if _, ok := swaggerLoader.visitedHeader[component.Value]; ok { return nil } @@ -500,6 +464,9 @@ func (swaggerLoader *SwaggerLoader) resolveHeaderRef(swagger *Swagger, component func (swaggerLoader *SwaggerLoader) resolveParameterRef(swagger *Swagger, component *ParameterRef, documentPath *url.URL) error { if component != nil && component.Value != nil { + if swaggerLoader.visitedParameter == nil { + swaggerLoader.visitedParameter = make(map[*Parameter]struct{}) + } if _, ok := swaggerLoader.visitedParameter[component.Value]; ok { return nil } @@ -560,6 +527,9 @@ func (swaggerLoader *SwaggerLoader) resolveParameterRef(swagger *Swagger, compon func (swaggerLoader *SwaggerLoader) resolveRequestBodyRef(swagger *Swagger, component *RequestBodyRef, documentPath *url.URL) error { if component != nil && component.Value != nil { + if swaggerLoader.visitedRequestBody == nil { + swaggerLoader.visitedRequestBody = make(map[*RequestBody]struct{}) + } if _, ok := swaggerLoader.visitedRequestBody[component.Value]; ok { return nil } @@ -612,6 +582,9 @@ func (swaggerLoader *SwaggerLoader) resolveRequestBodyRef(swagger *Swagger, comp func (swaggerLoader *SwaggerLoader) resolveResponseRef(swagger *Swagger, component *ResponseRef, documentPath *url.URL) error { if component != nil && component.Value != nil { + if swaggerLoader.visitedResponse == nil { + swaggerLoader.visitedResponse = make(map[*Response]struct{}) + } if _, ok := swaggerLoader.visitedResponse[component.Value]; ok { return nil } @@ -683,6 +656,9 @@ func (swaggerLoader *SwaggerLoader) resolveResponseRef(swagger *Swagger, compone func (swaggerLoader *SwaggerLoader) resolveSchemaRef(swagger *Swagger, component *SchemaRef, documentPath *url.URL) error { if component != nil && component.Value != nil { + if swaggerLoader.visitedSchema == nil { + swaggerLoader.visitedSchema = make(map[*Schema]struct{}) + } if _, ok := swaggerLoader.visitedSchema[component.Value]; ok { return nil } @@ -766,6 +742,9 @@ func (swaggerLoader *SwaggerLoader) resolveSchemaRef(swagger *Swagger, component func (swaggerLoader *SwaggerLoader) resolveSecuritySchemeRef(swagger *Swagger, component *SecuritySchemeRef, documentPath *url.URL) error { if component != nil && component.Value != nil { + if swaggerLoader.visitedSecurityScheme == nil { + swaggerLoader.visitedSecurityScheme = make(map[*SecurityScheme]struct{}) + } if _, ok := swaggerLoader.visitedSecurityScheme[component.Value]; ok { return nil } @@ -801,6 +780,9 @@ func (swaggerLoader *SwaggerLoader) resolveSecuritySchemeRef(swagger *Swagger, c func (swaggerLoader *SwaggerLoader) resolveExampleRef(swagger *Swagger, component *ExampleRef, documentPath *url.URL) error { if component != nil && component.Value != nil { + if swaggerLoader.visitedExample == nil { + swaggerLoader.visitedExample = make(map[*Example]struct{}) + } if _, ok := swaggerLoader.visitedExample[component.Value]; ok { return nil } @@ -836,6 +818,9 @@ func (swaggerLoader *SwaggerLoader) resolveExampleRef(swagger *Swagger, componen func (swaggerLoader *SwaggerLoader) resolveLinkRef(swagger *Swagger, component *LinkRef, documentPath *url.URL) error { if component != nil && component.Value != nil { + if swaggerLoader.visitedLink == nil { + swaggerLoader.visitedLink = make(map[*Link]struct{}) + } if _, ok := swaggerLoader.visitedLink[component.Value]; ok { return nil } @@ -875,10 +860,10 @@ func (swaggerLoader *SwaggerLoader) resolvePathItemRef(swagger *Swagger, entrypo key = documentPath.EscapedPath() } key += entrypoint - if _, ok := swaggerLoader.visitedFiles[key]; ok { + if _, ok := swaggerLoader.visitedPathItemRefs[key]; ok { return nil } - swaggerLoader.visitedFiles[key] = struct{}{} + swaggerLoader.visitedPathItemRefs[key] = struct{}{} const prefix = "#/paths/" if pathItem == nil { diff --git a/openapi3/swagger_loader_read_from_uri_func_test.go b/openapi3/swagger_loader_read_from_uri_func_test.go index 3ac0ed74f..b15767855 100644 --- a/openapi3/swagger_loader_read_from_uri_func_test.go +++ b/openapi3/swagger_loader_read_from_uri_func_test.go @@ -1,6 +1,7 @@ package openapi3 import ( + "fmt" "io/ioutil" "net/url" "path/filepath" @@ -21,3 +22,52 @@ func TestLoaderReadFromURIFunc(t *testing.T) { require.NoError(t, doc.Validate(loader.Context)) require.Equal(t, "bar", doc.Paths["/foo"].Get.Responses.Get(200).Value.Content.Get("application/json").Schema.Value.Properties["foo"].Value.Properties["bar"].Value.Items.Value.Example) } + +type multipleSourceSwaggerLoaderExample struct { + Sources map[string][]byte +} + +func (l *multipleSourceSwaggerLoaderExample) LoadSwaggerFromURI( + loader *SwaggerLoader, + location *url.URL, +) ([]byte, error) { + source := l.resolveSourceFromURI(location) + if source == nil { + return nil, fmt.Errorf("Unsupported URI: %q", location.String()) + } + return source, nil +} + +func (l *multipleSourceSwaggerLoaderExample) resolveSourceFromURI(location fmt.Stringer) []byte { + return l.Sources[location.String()] +} + +func TestResolveSchemaExternalRef(t *testing.T) { + rootLocation := &url.URL{Scheme: "http", Host: "example.com", Path: "spec.json"} + externalLocation := &url.URL{Scheme: "http", Host: "example.com", Path: "external.json"} + rootSpec := []byte(fmt.Sprintf( + `{"openapi":"3.0.0","info":{"title":"MyAPI","version":"0.1","description":"An API"},"paths":{},"components":{"schemas":{"Root":{"allOf":[{"$ref":"%s#/components/schemas/External"}]}}}}`, + externalLocation.String(), + )) + externalSpec := []byte(`{"openapi":"3.0.0","info":{"title":"MyAPI","version":"0.1","description":"External Spec"},"paths":{},"components":{"schemas":{"External":{"type":"string"}}}}`) + multipleSourceLoader := &multipleSourceSwaggerLoaderExample{ + Sources: map[string][]byte{ + rootLocation.String(): rootSpec, + externalLocation.String(): externalSpec, + }, + } + loader := &SwaggerLoader{ + IsExternalRefsAllowed: true, + ReadFromURIFunc: multipleSourceLoader.LoadSwaggerFromURI, + } + + doc, err := loader.LoadSwaggerFromURI(rootLocation) + require.NoError(t, err) + + err = doc.Validate(loader.Context) + require.NoError(t, err) + + refRootVisited := doc.Components.Schemas["Root"].Value.AllOf[0] + require.Equal(t, fmt.Sprintf("%s#/components/schemas/External", externalLocation.String()), refRootVisited.Ref) + require.NotNil(t, refRootVisited.Value) +} diff --git a/openapi3/swagger_loader_test.go b/openapi3/swagger_loader_test.go index 578932039..898ae9fe4 100644 --- a/openapi3/swagger_loader_test.go +++ b/openapi3/swagger_loader_test.go @@ -132,70 +132,6 @@ paths: require.Equal(t, example.Value.Value.(map[string]interface{})["error"].(bool), false) } -type sourceExample struct { - Location *url.URL - Spec []byte -} - -type multipleSourceSwaggerLoaderExample struct { - Sources []*sourceExample -} - -func (l *multipleSourceSwaggerLoaderExample) LoadSwaggerFromURI( - loader *SwaggerLoader, - location *url.URL, -) (*Swagger, error) { - source := l.resolveSourceFromURI(location) - if source == nil { - return nil, fmt.Errorf("Unsupported URI: '%s'", location.String()) - } - return loader.LoadSwaggerFromData(source.Spec) -} - -func (l *multipleSourceSwaggerLoaderExample) resolveSourceFromURI(location fmt.Stringer) *sourceExample { - locationString := location.String() - for _, v := range l.Sources { - if v.Location.String() == locationString { - return v - } - } - return nil -} - -func TestResolveSchemaExternalRef(t *testing.T) { - rootLocation := &url.URL{Scheme: "http", Host: "example.com", Path: "spec.json"} - externalLocation := &url.URL{Scheme: "http", Host: "example.com", Path: "external.json"} - rootSpec := []byte(fmt.Sprintf( - `{"openapi":"3.0.0","info":{"title":"MyAPI","version":"0.1","description":"An API"},"paths":{},"components":{"schemas":{"Root":{"allOf":[{"$ref":"%s#/components/schemas/External"}]}}}}`, - externalLocation.String(), - )) - externalSpec := []byte(`{"openapi":"3.0.0","info":{"title":"MyAPI","version":"0.1","description":"External Spec"},"paths":{},"components":{"schemas":{"External":{"type":"string"}}}}`) - multipleSourceLoader := &multipleSourceSwaggerLoaderExample{ - Sources: []*sourceExample{ - { - Location: rootLocation, - Spec: rootSpec, - }, - { - Location: externalLocation, - Spec: externalSpec, - }, - }, - } - loader := &SwaggerLoader{ - IsExternalRefsAllowed: true, - LoadSwaggerFromURIFunc: multipleSourceLoader.LoadSwaggerFromURI, - } - doc, err := loader.LoadSwaggerFromURI(rootLocation) - require.NoError(t, err) - err = doc.Validate(loader.Context) - - require.NoError(t, err) - refRootVisited := doc.Components.Schemas["Root"].Value.AllOf[0] - require.Equal(t, fmt.Sprintf("%s#/components/schemas/External", externalLocation.String()), refRootVisited.Ref) - require.NotNil(t, refRootVisited.Value) -} - func TestLoadErrorOnRefMisuse(t *testing.T) { spec := []byte(` openapi: '3.0.0'