Skip to content

Commit

Permalink
Support reference cycles (#393)
Browse files Browse the repository at this point in the history
  • Loading branch information
DerekStrickland committed Aug 3, 2021
1 parent ed98f50 commit 7fd2ca1
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 5 deletions.
60 changes: 56 additions & 4 deletions openapi3gen/openapi3gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openapi3gen

import (
"encoding/json"
"fmt"
"math"
"reflect"
"strings"
Expand All @@ -22,6 +23,7 @@ type Option func(*generatorOpt)

type generatorOpt struct {
useAllExportedFields bool
throwErrorOnCycle bool
}

// UseAllExportedFields changes the default behavior of only
Expand All @@ -30,6 +32,12 @@ func UseAllExportedFields() Option {
return func(x *generatorOpt) { x.useAllExportedFields = true }
}

// ThrowErrorOnCycle changes the default behavior of creating cycle
// refs to instead error if a cycle is detected.
func ThrowErrorOnCycle() Option {
return func(x *generatorOpt) { x.throwErrorOnCycle = true }
}

// NewSchemaRefForValue uses reflection on the given value to produce a SchemaRef.
func NewSchemaRefForValue(value interface{}, opts ...Option) (*openapi3.SchemaRef, map[*openapi3.SchemaRef]int, error) {
g := NewGenerator(opts...)
Expand Down Expand Up @@ -104,6 +112,10 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
if a && b {
vs, err := g.generateSchemaRefFor(parents, v.Type)
if err != nil {
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
g.SchemaRefs[vs]++
return vs, nil
}
return nil, err
}
refSchemaRef := RefSchemaRef
Expand Down Expand Up @@ -185,7 +197,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
schema.Type = "array"
items, err := g.generateSchemaRefFor(parents, t.Elem())
if err != nil {
return nil, err
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
items = g.generateCycleSchemaRef(t.Elem(), schema)
} else {
return nil, err
}
}
if items != nil {
g.SchemaRefs[items]++
Expand All @@ -197,7 +213,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
schema.Type = "object"
additionalProperties, err := g.generateSchemaRefFor(parents, t.Elem())
if err != nil {
return nil, err
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
additionalProperties = g.generateCycleSchemaRef(t.Elem(), schema)
} else {
return nil, err
}
}
if additionalProperties != nil {
g.SchemaRefs[additionalProperties]++
Expand All @@ -221,7 +241,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
if t.Field(fieldInfo.Index[0]).Anonymous {
ref, err := g.generateSchemaRefFor(parents, fType)
if err != nil {
return nil, err
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
ref = g.generateCycleSchemaRef(fType, schema)
} else {
return nil, err
}
}
if ref != nil {
g.SchemaRefs[ref]++
Expand All @@ -237,7 +261,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec

ref, err := g.generateSchemaRefFor(parents, fType)
if err != nil {
return nil, err
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
ref = g.generateCycleSchemaRef(fType, schema)
} else {
return nil, err
}
}
if ref != nil {
g.SchemaRefs[ref]++
Expand All @@ -255,6 +283,30 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
return openapi3.NewSchemaRef(t.Name(), schema), nil
}

func (g *Generator) generateCycleSchemaRef(t reflect.Type, schema *openapi3.Schema) *openapi3.SchemaRef {
var typeName string
switch t.Kind() {
case reflect.Ptr:
return g.generateCycleSchemaRef(t.Elem(), schema)
case reflect.Slice:
ref := g.generateCycleSchemaRef(t.Elem(), schema)
sliceSchema := openapi3.NewSchema()
sliceSchema.Type = "array"
sliceSchema.Items = ref
return openapi3.NewSchemaRef("", sliceSchema)
case reflect.Map:
ref := g.generateCycleSchemaRef(t.Elem(), schema)
mapSchema := openapi3.NewSchema()
mapSchema.Type = "object"
mapSchema.AdditionalProperties = ref
return openapi3.NewSchemaRef("", mapSchema)
default:
typeName = t.Name()
}

return openapi3.NewSchemaRef(fmt.Sprintf("#/components/schemas/%s", typeName), schema)
}

var RefSchemaRef = openapi3.NewSchemaRef("Ref",
openapi3.NewObjectSchema().WithProperty("$ref", openapi3.NewStringSchema().WithMinLength(1)))

Expand Down
32 changes: 31 additions & 1 deletion openapi3gen/openapi3gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type CyclicType1 struct {
}

func TestCyclic(t *testing.T) {
schemaRef, refsMap, err := NewSchemaRefForValue(&CyclicType0{})
schemaRef, refsMap, err := NewSchemaRefForValue(&CyclicType0{}, ThrowErrorOnCycle())
require.IsType(t, &CycleError{}, err)
require.Nil(t, schemaRef)
require.Empty(t, refsMap)
Expand Down Expand Up @@ -84,3 +84,33 @@ func TestEmbeddedStructs(t *testing.T) {
_, ok = schemaRef.Value.Properties["ID"]
require.Equal(t, true, ok)
}

func TestCyclicReferences(t *testing.T) {
type ObjectDiff struct {
FieldCycle *ObjectDiff
SliceCycle []*ObjectDiff
MapCycle map[*ObjectDiff]*ObjectDiff
}

instance := &ObjectDiff{
FieldCycle: nil,
SliceCycle: nil,
MapCycle: nil,
}

generator := NewGenerator(UseAllExportedFields())

schemaRef, err := generator.GenerateSchemaRef(reflect.TypeOf(instance))
require.NoError(t, err)

require.NotNil(t, schemaRef.Value.Properties["FieldCycle"])
require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["FieldCycle"].Ref)

require.NotNil(t, schemaRef.Value.Properties["SliceCycle"])
require.Equal(t, "array", schemaRef.Value.Properties["SliceCycle"].Value.Type)
require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["SliceCycle"].Value.Items.Ref)

require.NotNil(t, schemaRef.Value.Properties["MapCycle"])
require.Equal(t, "object", schemaRef.Value.Properties["MapCycle"].Value.Type)
require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["MapCycle"].Value.AdditionalProperties.Ref)
}

0 comments on commit 7fd2ca1

Please sign in to comment.