diff --git a/openapi3gen/openapi3gen.go b/openapi3gen/openapi3gen.go index b4ae7b04c..1c233be12 100644 --- a/openapi3gen/openapi3gen.go +++ b/openapi3gen/openapi3gen.go @@ -52,14 +52,10 @@ func SchemaCustomizer(sc SchemaCustomizerFn) Option { return func(x *generatorOpt) { x.schemaCustomizer = sc } } -// NewSchemaRefForValue uses reflection on the given value to produce a SchemaRef. -func NewSchemaRefForValue(value interface{}, opts ...Option) (*openapi3.SchemaRef, map[*openapi3.SchemaRef]int, error) { +// NewSchemaRefForValue uses reflection on the given value to produce a SchemaRef, and updates a supplied map with any dependent component schemas (for cycles) +func NewSchemaRefForValue(value interface{}, schemas openapi3.Schemas, opts ...Option) (*openapi3.SchemaRef, error) { g := NewGenerator(opts...) - ref, err := g.GenerateSchemaRef(reflect.TypeOf(value)) - for ref := range g.SchemaRefs { - ref.Ref = "" - } - return ref, g.SchemaRefs, err + return g.newSchemaRefForValue(value, schemas) } type Generator struct { @@ -71,6 +67,9 @@ type Generator struct { // If count is 1, it's not ne // An OpenAPI identifier has been assigned to each. SchemaRefs map[*openapi3.SchemaRef]int + + // ComponentSchemas contains a map of schemas that must be defined in the components, due to cycles + ComponentSchemas map[string]bool } func NewGenerator(opts ...Option) *Generator { @@ -79,9 +78,10 @@ func NewGenerator(opts ...Option) *Generator { f(gOpt) } return &Generator{ - Types: make(map[reflect.Type]*openapi3.SchemaRef), - SchemaRefs: make(map[*openapi3.SchemaRef]int), - opts: *gOpt, + Types: make(map[reflect.Type]*openapi3.SchemaRef), + SchemaRefs: make(map[*openapi3.SchemaRef]int), + ComponentSchemas: make(map[string]bool), + opts: *gOpt, } } @@ -90,6 +90,23 @@ func (g *Generator) GenerateSchemaRef(t reflect.Type) (*openapi3.SchemaRef, erro return g.generateSchemaRefFor(nil, t, "_root", "") } +func (g *Generator) newSchemaRefForValue(value interface{}, schemas openapi3.Schemas) (*openapi3.SchemaRef, error) { + ref, err := g.GenerateSchemaRef(reflect.TypeOf(value)) + for ref := range g.SchemaRefs { + if g.ComponentSchemas[ref.Ref] && schemas != nil { + schemas[ref.Ref] = &openapi3.SchemaRef{ + Value: ref.Value, + } + } + if strings.HasPrefix(ref.Ref, "#/components/schemas/") { + ref.Value = nil + } else { + ref.Ref = "" + } + } + return ref, err +} + func (g *Generator) generateSchemaRefFor(parents []*jsoninfo.TypeInfo, t reflect.Type, name string, tag reflect.StructTag) (*openapi3.SchemaRef, error) { if ref := g.Types[t]; ref != nil && g.opts.schemaCustomizer == nil { g.SchemaRefs[ref]++ @@ -341,6 +358,7 @@ func (g *Generator) generateCycleSchemaRef(t reflect.Type, schema *openapi3.Sche typeName = t.Name() } + g.ComponentSchemas[typeName] = true return openapi3.NewSchemaRef(fmt.Sprintf("#/components/schemas/%s", typeName), schema) } diff --git a/openapi3gen/openapi3gen_test.go b/openapi3gen/openapi3gen_test.go index 6d96db98e..ee8530c8f 100644 --- a/openapi3gen/openapi3gen_test.go +++ b/openapi3gen/openapi3gen_test.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/getkin/kin-openapi/openapi3" "github.com/stretchr/testify/require" @@ -19,11 +20,106 @@ type CyclicType1 struct { CyclicField *CyclicType0 `json:"b"` } +func TestSimpleStruct(t *testing.T) { + type SomeOtherType string + + type SomeStruct struct { + Bool bool `json:"bool"` + Int int `json:"int"` + Int64 int64 `json:"int64"` + Float64 float64 `json:"float64"` + String string `json:"string"` + Bytes []byte `json:"bytes"` + JSON json.RawMessage `json:"json"` + Time time.Time `json:"time"` + Slice []SomeOtherType `json:"slice"` + Map map[string]*SomeOtherType `json:"map"` + + Struct struct { + X string `json:"x"` + } `json:"struct"` + + EmptyStruct struct { + Y string + } `json:"structWithoutFields"` + + Ptr *SomeOtherType `json:"ptr"` + } + + g := NewGenerator() + schemaRef, err := g.newSchemaRefForValue(&SomeStruct{}, nil) + require.NoError(t, err) + require.Len(t, g.SchemaRefs, 15) + + schemaJSON, err := json.Marshal(schemaRef) + require.NoError(t, err) + + require.JSONEq(t, ` + { + "properties": { + "bool": { + "type": "boolean" + }, + "bytes": { + "format": "byte", + "type": "string" + }, + "float64": { + "format": "double", + "type": "number" + }, + "int": { + "type": "integer" + }, + "int64": { + "format": "int64", + "type": "integer" + }, + "json": {}, + "map": { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + "ptr": { + "type": "string" + }, + "slice": { + "items": { + "type": "string" + }, + "type": "array" + }, + "string": { + "type": "string" + }, + "struct": { + "properties": { + "x": { + "type": "string" + } + }, + "type": "object" + }, + "structWithoutFields": {}, + "time": { + "format": "date-time", + "type": "string" + } + }, + "type": "object" + } + `, string(schemaJSON)) + +} + func TestCyclic(t *testing.T) { - schemaRef, refsMap, err := NewSchemaRefForValue(&CyclicType0{}, ThrowErrorOnCycle()) + g := NewGenerator(ThrowErrorOnCycle()) + schemaRef, err := g.newSchemaRefForValue(&CyclicType0{}, nil) require.IsType(t, &CycleError{}, err) require.Nil(t, schemaRef) - require.Empty(t, refsMap) + require.Empty(t, g.SchemaRefs) } func TestExportedNonTagged(t *testing.T) { @@ -34,7 +130,7 @@ func TestExportedNonTagged(t *testing.T) { EvenAYaml string `yaml:"even_a_yaml"` } - schemaRef, _, err := NewSchemaRefForValue(&Bla{}, UseAllExportedFields()) + schemaRef, err := NewSchemaRefForValue(&Bla{}, nil, UseAllExportedFields()) require.NoError(t, err) require.Equal(t, &openapi3.SchemaRef{Value: &openapi3.Schema{ Type: "object", @@ -50,7 +146,7 @@ func TestExportUint(t *testing.T) { UnsignedInt uint `json:"uint"` } - schemaRef, _, err := NewSchemaRefForValue(&UnsignedIntStruct{}, UseAllExportedFields()) + schemaRef, err := NewSchemaRefForValue(&UnsignedIntStruct{}, nil, UseAllExportedFields()) require.NoError(t, err) require.Equal(t, &openapi3.SchemaRef{Value: &openapi3.Schema{ Type: "object", @@ -169,7 +265,7 @@ func TestSchemaCustomizer(t *testing.T) { EnumField3 string `json:"enum3" myenumtag:"e,f"` } - schemaRef, _, err := NewSchemaRefForValue(&Bla{}, UseAllExportedFields(), SchemaCustomizer(func(name string, ft reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { + schemaRef, err := NewSchemaRefForValue(&Bla{}, nil, UseAllExportedFields(), SchemaCustomizer(func(name string, ft reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { t.Logf("Field=%s,Tag=%s", name, tag) if tag.Get("mymintag") != "" { minVal, err := strconv.ParseFloat(tag.Get("mymintag"), 64) @@ -241,8 +337,73 @@ func TestSchemaCustomizer(t *testing.T) { func TestSchemaCustomizerError(t *testing.T) { type Bla struct{} - _, _, err := NewSchemaRefForValue(&Bla{}, UseAllExportedFields(), SchemaCustomizer(func(name string, ft reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { + _, err := NewSchemaRefForValue(&Bla{}, nil, UseAllExportedFields(), SchemaCustomizer(func(name string, ft reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { return errors.New("test error") })) require.EqualError(t, err, "test error") } + +func TestRecursiveSchema(t *testing.T) { + + type RecursiveType struct { + Field1 string `json:"field1"` + Field2 string `json:"field2"` + Field3 string `json:"field3"` + Components []*RecursiveType `json:"children,omitempty"` + } + + schemas := make(openapi3.Schemas) + schemaRef, err := NewSchemaRefForValue(&RecursiveType{}, schemas) + require.NoError(t, err) + + jsonSchemas, err := json.MarshalIndent(&schemas, "", " ") + require.NoError(t, err) + + jsonSchemaRef, err := json.MarshalIndent(&schemaRef, "", " ") + require.NoError(t, err) + + require.JSONEq(t, `{ + "RecursiveType": { + "properties": { + "children": { + "items": { + "$ref": "#/components/schemas/RecursiveType" + }, + "type": "array" + }, + "field1": { + "type": "string" + }, + "field2": { + "type": "string" + }, + "field3": { + "type": "string" + } + }, + "type": "object" + } + }`, string(jsonSchemas)) + + require.JSONEq(t, `{ + "properties": { + "children": { + "items": { + "$ref": "#/components/schemas/RecursiveType" + }, + "type": "array" + }, + "field1": { + "type": "string" + }, + "field2": { + "type": "string" + }, + "field3": { + "type": "string" + } + }, + "type": "object" + }`, string(jsonSchemaRef)) + +} diff --git a/openapi3gen/simple_test.go b/openapi3gen/simple_test.go index d997e23b2..99e94ae12 100644 --- a/openapi3gen/simple_test.go +++ b/openapi3gen/simple_test.go @@ -36,15 +36,11 @@ type ( ) func Example() { - schemaRef, refsMap, err := openapi3gen.NewSchemaRefForValue(&SomeStruct{}) + schemaRef, err := openapi3gen.NewSchemaRefForValue(&SomeStruct{}, nil) if err != nil { panic(err) } - if len(refsMap) != 15 { - panic(fmt.Sprintf("unintended len(refsMap) = %d", len(refsMap))) - } - data, err := json.MarshalIndent(schemaRef, "", " ") if err != nil { panic(err)