Skip to content

Commit

Permalink
Fix defaulter for nested recursive types
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhita committed Jun 11, 2017
1 parent c79c13d commit ba7ac3b
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions examples/defaulter-gen/generators/defaulter.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,10 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
if d.object != nil {
continue
}
if buildCallTreeForType(t, true, existingDefaulters, newDefaulters) != nil {
// existingTypes keeps track of types that have already been visited in the tree.
// This is used to avoid recursion for recursive types.
existingTypes := make(map[*types.Type]bool)
if buildCallTreeForType(t, true, existingDefaulters, newDefaulters, existingTypes) != nil {
args := defaultingArgsFromType(t)
sw.Do("$.inType|objectdefaultfn$", args)
newDefaulters[t] = defaults{
Expand Down Expand Up @@ -396,7 +399,7 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
// that could be or will be generated. If newDefaulters has an entry for a type, but the 'object' field is nil,
// this function skips adding that defaulter - this allows us to avoid generating object defaulter functions for
// list types that call empty defaulters.
func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefaulters defaulterFuncMap) *callNode {
func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefaulters defaulterFuncMap, existingTypes map[*types.Type]bool) *callNode {
parent := &callNode{}

if root {
Expand Down Expand Up @@ -432,19 +435,27 @@ func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefau
// base has been added already, now add any additional defaulters defined for this object
parent.call = append(parent.call, defaults.additional...)

// if the type already exists, don't build the tree for it and don't generate anything.
// This is used to avoid recursion for nested recursive types.
if existingTypes[t] {
return nil
}
// if type doesn't exist, mark it as existing
existingTypes[t] = true

switch t.Kind {
case types.Pointer:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters, existingTypes); child != nil {
child.elem = true
parent.children = append(parent.children, *child)
}
case types.Slice, types.Array:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters, existingTypes); child != nil {
child.index = true
parent.children = append(parent.children, *child)
}
case types.Map:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters, existingTypes); child != nil {
child.key = true
parent.children = append(parent.children, *child)
}
Expand All @@ -458,20 +469,24 @@ func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefau
name = field.Type.Name.Name
}
}
if child := buildCallTreeForType(field.Type, false, existingDefaulters, newDefaulters); child != nil {
if child := buildCallTreeForType(field.Type, false, existingDefaulters, newDefaulters, existingTypes); child != nil {
child.field = name
parent.children = append(parent.children, *child)
}
}
case types.Alias:
if child := buildCallTreeForType(t.Underlying, false, existingDefaulters, newDefaulters); child != nil {
if child := buildCallTreeForType(t.Underlying, false, existingDefaulters, newDefaulters, existingTypes); child != nil {
parent.children = append(parent.children, *child)
}
}
if len(parent.children) == 0 && len(parent.call) == 0 {
//glog.V(6).Infof("decided type %s needs no generation", t.Name)
return nil
}

// The type now acts as a parent, not a nested recursive type.
// We can now build the tree for it safely.
existingTypes[t] = false
return parent
}

Expand Down Expand Up @@ -571,7 +586,8 @@ func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Wr

glog.V(5).Infof("generating for type %v", t)

callTree := buildCallTreeForType(t, true, g.existingDefaulters, g.newDefaulters)
existingTypes := make(map[*types.Type]bool)
callTree := buildCallTreeForType(t, true, g.existingDefaulters, g.newDefaulters, existingTypes)
if callTree == nil {
glog.V(5).Infof(" no defaulters defined")
return nil
Expand Down

0 comments on commit ba7ac3b

Please sign in to comment.