From 2692f43ba21c89366b2a221a86be520b87539352 Mon Sep 17 00:00:00 2001 From: Pierre Fenoll Date: Fri, 5 Jul 2024 17:57:58 +0200 Subject: [PATCH] openapi3: allow YAML-marshaling invalid specs (#977) * openapi3: allow YAML-marshaling invalid specs Signed-off-by: Pierre Fenoll * fixes Signed-off-by: Pierre Fenoll --------- Signed-off-by: Pierre Fenoll --- .github/docs/openapi3.txt | 2 +- maps.sh | 7 ++++ openapi3/info.go | 5 ++- openapi3/issue972_test.go | 66 +++++++++++++++++++++++++++++++++ openapi3/loader.go | 3 +- openapi3/maplike.go | 9 +++++ openapi3/openapi3.go | 3 ++ openapi3/schema_formats_test.go | 4 +- 8 files changed, 93 insertions(+), 6 deletions(-) create mode 100644 openapi3/issue972_test.go diff --git a/.github/docs/openapi3.txt b/.github/docs/openapi3.txt index e0e4c70c3..3a7caf2c8 100644 --- a/.github/docs/openapi3.txt +++ b/.github/docs/openapi3.txt @@ -645,7 +645,7 @@ type Info struct { func (info Info) MarshalJSON() ([]byte, error) MarshalJSON returns the JSON encoding of Info. -func (info Info) MarshalYAML() (any, error) +func (info *Info) MarshalYAML() (any, error) MarshalYAML returns the YAML encoding of Info. func (info *Info) UnmarshalJSON(data []byte) error diff --git a/maps.sh b/maps.sh index b23c36019..9cfd0ffdc 100755 --- a/maps.sh +++ b/maps.sh @@ -155,9 +155,16 @@ EOF maplike_UnMarsh() { + if [[ "$type" != '*'* ]]; then + echo "TODO: impl non-pointer receiver YAML Marshaler" + exit 2 + fi cat <>"$maplike" // MarshalYAML returns the YAML encoding of ${type#'*'}. func (${name} ${type}) MarshalYAML() (any, error) { + if ${name} == nil { + return nil, nil + } m := make(map[string]any, ${name}.Len()+len(${name}.Extensions)) for k, v := range ${name}.Extensions { m[k] = v diff --git a/openapi3/info.go b/openapi3/info.go index 7326bcc0d..e2468285c 100644 --- a/openapi3/info.go +++ b/openapi3/info.go @@ -29,7 +29,10 @@ func (info Info) MarshalJSON() ([]byte, error) { } // MarshalYAML returns the YAML encoding of Info. -func (info Info) MarshalYAML() (any, error) { +func (info *Info) MarshalYAML() (any, error) { + if info == nil { + return nil, nil + } m := make(map[string]any, 6+len(info.Extensions)) for k, v := range info.Extensions { m[k] = v diff --git a/openapi3/issue972_test.go b/openapi3/issue972_test.go new file mode 100644 index 000000000..3575adc9d --- /dev/null +++ b/openapi3/issue972_test.go @@ -0,0 +1,66 @@ +package openapi3 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" +) + +func TestIssue972(t *testing.T) { + type testcase struct { + spec string + validationErrorContains string + } + + base := ` +openapi: 3.0.2 +paths: {} +components: {} +` + + for _, tc := range []testcase{{ + spec: base, + validationErrorContains: "invalid info: must be an object", + }, { + spec: base + ` +info: + title: "Hello World REST APIs" + version: "1.0" +`, + }, { + spec: base + ` +info: null +`, + validationErrorContains: "invalid info: must be an object", + }, { + spec: base + ` +info: {} +`, + validationErrorContains: "invalid info: value of version must be a non-empty string", + }, { + spec: base + ` +info: + title: "Hello World REST APIs" +`, + validationErrorContains: "invalid info: value of version must be a non-empty string", + }} { + t.Logf("spec: %s", tc.spec) + + loader := &Loader{} + doc, err := loader.LoadFromData([]byte(tc.spec)) + assert.NoError(t, err) + assert.NotNil(t, doc) + + err = doc.Validate(loader.Context) + if e := tc.validationErrorContains; e != "" { + assert.ErrorContains(t, err, e) + } else { + assert.NoError(t, err) + } + + txt, err := yaml.Marshal(doc) + assert.NoError(t, err) + assert.NotEmpty(t, txt) + } +} diff --git a/openapi3/loader.go b/openapi3/loader.go index 567d7394e..4f2766a0f 100644 --- a/openapi3/loader.go +++ b/openapi3/loader.go @@ -1140,8 +1140,7 @@ func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPat *pathItem = p } else { var resolved PathItem - doc, documentPath, err = loader.resolveComponent(doc, ref, documentPath, &resolved) - if err != nil { + if doc, documentPath, err = loader.resolveComponent(doc, ref, documentPath, &resolved); err != nil { if err == errMUSTPathItem { return nil } diff --git a/openapi3/maplike.go b/openapi3/maplike.go index 2085b7e2b..7b8045c67 100644 --- a/openapi3/maplike.go +++ b/openapi3/maplike.go @@ -78,6 +78,9 @@ func (responses Responses) JSONLookup(token string) (any, error) { // MarshalYAML returns the YAML encoding of Responses. func (responses *Responses) MarshalYAML() (any, error) { + if responses == nil { + return nil, nil + } m := make(map[string]any, responses.Len()+len(responses.Extensions)) for k, v := range responses.Extensions { m[k] = v @@ -206,6 +209,9 @@ func (callback Callback) JSONLookup(token string) (any, error) { // MarshalYAML returns the YAML encoding of Callback. func (callback *Callback) MarshalYAML() (any, error) { + if callback == nil { + return nil, nil + } m := make(map[string]any, callback.Len()+len(callback.Extensions)) for k, v := range callback.Extensions { m[k] = v @@ -334,6 +340,9 @@ func (paths Paths) JSONLookup(token string) (any, error) { // MarshalYAML returns the YAML encoding of Paths. func (paths *Paths) MarshalYAML() (any, error) { + if paths == nil { + return nil, nil + } m := make(map[string]any, paths.Len()+len(paths.Extensions)) for k, v := range paths.Extensions { m[k] = v diff --git a/openapi3/openapi3.go b/openapi3/openapi3.go index f8228012f..ef1592e8c 100644 --- a/openapi3/openapi3.go +++ b/openapi3/openapi3.go @@ -66,6 +66,9 @@ func (doc *T) MarshalJSON() ([]byte, error) { // MarshalYAML returns the YAML encoding of T. func (doc *T) MarshalYAML() (any, error) { + if doc == nil { + return nil, nil + } m := make(map[string]any, 4+len(doc.Extensions)) for k, v := range doc.Extensions { m[k] = v diff --git a/openapi3/schema_formats_test.go b/openapi3/schema_formats_test.go index 09cf3bfe5..f7c6a08cb 100644 --- a/openapi3/schema_formats_test.go +++ b/openapi3/schema_formats_test.go @@ -169,13 +169,13 @@ func TestNumberFormats(t *testing.T) { } DefineNumberFormatValidator("lessThan10", NewCallbackValidator(func(value float64) error { if value >= 10 { - return fmt.Errorf("not less than 10") + return errors.New("not less than 10") } return nil })) DefineIntegerFormatValidator("odd", NewCallbackValidator(func(value int64) error { if value%2 == 0 { - return fmt.Errorf("not odd") + return errors.New("not odd") } return nil }))