diff --git a/openapi3gen/openapi3gen.go b/openapi3gen/openapi3gen.go index 84d1f998d..7c321fe7a 100644 --- a/openapi3gen/openapi3gen.go +++ b/openapi3gen/openapi3gen.go @@ -3,6 +3,7 @@ package openapi3gen import ( "encoding/json" + "fmt" "math" "reflect" "strings" @@ -22,6 +23,7 @@ type Option func(*generatorOpt) type generatorOpt struct { useAllExportedFields bool + throwErrorOnCycle bool } // UseAllExportedFields changes the default behavior of only @@ -30,6 +32,12 @@ func UseAllExportedFields() Option { return func(x *generatorOpt) { x.useAllExportedFields = true } } +// ThrowErrorOnCycle changes the default behavior of creating cycle +// refs to instead error if a cycle is detected. +func ThrowErrorOnCycle() Option { + return func(x *generatorOpt) { x.throwErrorOnCycle = true } +} + // NewSchemaRefForValue uses reflection on the given value to produce a SchemaRef. func NewSchemaRefForValue(value interface{}, opts ...Option) (*openapi3.SchemaRef, map[*openapi3.SchemaRef]int, error) { g := NewGenerator(opts...) @@ -104,6 +112,10 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec if a && b { vs, err := g.generateSchemaRefFor(parents, v.Type) if err != nil { + if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { + g.SchemaRefs[vs]++ + return vs, nil + } return nil, err } refSchemaRef := RefSchemaRef @@ -185,7 +197,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec schema.Type = "array" items, err := g.generateSchemaRefFor(parents, t.Elem()) if err != nil { - return nil, err + if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { + items = g.generateCycleSchemaRef(t.Elem(), schema) + } else { + return nil, err + } } if items != nil { g.SchemaRefs[items]++ @@ -197,7 +213,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec schema.Type = "object" additionalProperties, err := g.generateSchemaRefFor(parents, t.Elem()) if err != nil { - return nil, err + if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { + additionalProperties = g.generateCycleSchemaRef(t.Elem(), schema) + } else { + return nil, err + } } if additionalProperties != nil { g.SchemaRefs[additionalProperties]++ @@ -221,7 +241,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec if t.Field(fieldInfo.Index[0]).Anonymous { ref, err := g.generateSchemaRefFor(parents, fType) if err != nil { - return nil, err + if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { + ref = g.generateCycleSchemaRef(fType, schema) + } else { + return nil, err + } } if ref != nil { g.SchemaRefs[ref]++ @@ -237,7 +261,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec ref, err := g.generateSchemaRefFor(parents, fType) if err != nil { - return nil, err + if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { + ref = g.generateCycleSchemaRef(fType, schema) + } else { + return nil, err + } } if ref != nil { g.SchemaRefs[ref]++ @@ -255,6 +283,30 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec return openapi3.NewSchemaRef(t.Name(), schema), nil } +func (g *Generator) generateCycleSchemaRef(t reflect.Type, schema *openapi3.Schema) *openapi3.SchemaRef { + var typeName string + switch t.Kind() { + case reflect.Ptr: + return g.generateCycleSchemaRef(t.Elem(), schema) + case reflect.Slice: + ref := g.generateCycleSchemaRef(t.Elem(), schema) + sliceSchema := openapi3.NewSchema() + sliceSchema.Type = "array" + sliceSchema.Items = ref + return openapi3.NewSchemaRef("", sliceSchema) + case reflect.Map: + ref := g.generateCycleSchemaRef(t.Elem(), schema) + mapSchema := openapi3.NewSchema() + mapSchema.Type = "object" + mapSchema.AdditionalProperties = ref + return openapi3.NewSchemaRef("", mapSchema) + default: + typeName = t.Name() + } + + return openapi3.NewSchemaRef(fmt.Sprintf("#/components/schemas/%s", typeName), schema) +} + var RefSchemaRef = openapi3.NewSchemaRef("Ref", openapi3.NewObjectSchema().WithProperty("$ref", openapi3.NewStringSchema().WithMinLength(1))) diff --git a/openapi3gen/openapi3gen_test.go b/openapi3gen/openapi3gen_test.go index f8e2430e2..0062e2e5f 100644 --- a/openapi3gen/openapi3gen_test.go +++ b/openapi3gen/openapi3gen_test.go @@ -16,7 +16,7 @@ type CyclicType1 struct { } func TestCyclic(t *testing.T) { - schemaRef, refsMap, err := NewSchemaRefForValue(&CyclicType0{}) + schemaRef, refsMap, err := NewSchemaRefForValue(&CyclicType0{}, ThrowErrorOnCycle()) require.IsType(t, &CycleError{}, err) require.Nil(t, schemaRef) require.Empty(t, refsMap) @@ -84,3 +84,33 @@ func TestEmbeddedStructs(t *testing.T) { _, ok = schemaRef.Value.Properties["ID"] require.Equal(t, true, ok) } + +func TestCyclicReferences(t *testing.T) { + type ObjectDiff struct { + FieldCycle *ObjectDiff + SliceCycle []*ObjectDiff + MapCycle map[*ObjectDiff]*ObjectDiff + } + + instance := &ObjectDiff{ + FieldCycle: nil, + SliceCycle: nil, + MapCycle: nil, + } + + generator := NewGenerator(UseAllExportedFields()) + + schemaRef, err := generator.GenerateSchemaRef(reflect.TypeOf(instance)) + require.NoError(t, err) + + require.NotNil(t, schemaRef.Value.Properties["FieldCycle"]) + require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["FieldCycle"].Ref) + + require.NotNil(t, schemaRef.Value.Properties["SliceCycle"]) + require.Equal(t, "array", schemaRef.Value.Properties["SliceCycle"].Value.Type) + require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["SliceCycle"].Value.Items.Ref) + + require.NotNil(t, schemaRef.Value.Properties["MapCycle"]) + require.Equal(t, "object", schemaRef.Value.Properties["MapCycle"].Value.Type) + require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["MapCycle"].Value.AdditionalProperties.Ref) +}