From d8e33d558e2c5fcd7f9092790780e68adbac0f1b Mon Sep 17 00:00:00 2001 From: Dan Scales Date: Sun, 21 Feb 2021 10:54:38 -0800 Subject: [PATCH] cmd/compile: deal with closures in generic functions and instantiated function values - Deal with closures in generic functions by fixing the stenciling code - Deal with instantiated function values (instantiated generic functions that are not immediately called) during stenciling. This requires changing the OFUNCINST node to an ONAME node for the appropriately instantiated function. We do this in a second pass, since this is uncommon, but requires editing the tree at multiple levels. - Check global assignments (as well as functions) for generic function instantiations. - Fix a bug in (*subst).typ where a generic type in a generic function may definitely not use all the type args of the function, so we need to translate the rparams of the type based on the tparams/targs of the function. - Added new test combine.go that tests out closures in generic functions and instantiated function values. - Added one new variant to the settable test. - Enabling inlining functions with closures for -G=3. (For now, set Ntype on closures in -G=3 mode to keep compatibility with later parts of compiler, and allow inlining of functions with closures.) Change-Id: Iea63d5704c322e42e2f750a83adc8b44f911d4ec Reviewed-on: https://go-review.googlesource.com/c/go/+/296269 Reviewed-by: Robert Griesemer Run-TryBot: Dan Scales TryBot-Result: Go Bot Trust: Dan Scales --- src/cmd/compile/internal/inline/inl.go | 2 +- src/cmd/compile/internal/noder/expr.go | 11 +- src/cmd/compile/internal/noder/stencil.go | 182 +++++++++++++++++----- test/typeparam/combine.go | 65 ++++++++ test/typeparam/settable.go | 27 +++- 5 files changed, 236 insertions(+), 51 deletions(-) create mode 100644 test/typeparam/combine.go diff --git a/src/cmd/compile/internal/inline/inl.go b/src/cmd/compile/internal/inline/inl.go index e961b10844808a..1d049298d724b2 100644 --- a/src/cmd/compile/internal/inline/inl.go +++ b/src/cmd/compile/internal/inline/inl.go @@ -354,7 +354,7 @@ func (v *hairyVisitor) doNode(n ir.Node) bool { return true case ir.OCLOSURE: - if base.Debug.InlFuncsWithClosures == 0 || base.Flag.G > 0 { + if base.Debug.InlFuncsWithClosures == 0 { v.reason = "not inlining functions with closures" return true } diff --git a/src/cmd/compile/internal/noder/expr.go b/src/cmd/compile/internal/noder/expr.go index b166d34eadc9f0..3fded144dcdaf3 100644 --- a/src/cmd/compile/internal/noder/expr.go +++ b/src/cmd/compile/internal/noder/expr.go @@ -325,19 +325,22 @@ func (g *irgen) compLit(typ types2.Type, lit *syntax.CompositeLit) ir.Node { return typecheck.Expr(ir.NewCompLitExpr(g.pos(lit), ir.OCOMPLIT, ir.TypeNode(g.typ(typ)), exprs)) } -func (g *irgen) funcLit(typ types2.Type, expr *syntax.FuncLit) ir.Node { +func (g *irgen) funcLit(typ2 types2.Type, expr *syntax.FuncLit) ir.Node { fn := ir.NewFunc(g.pos(expr)) fn.SetIsHiddenClosure(ir.CurFunc != nil) fn.Nname = ir.NewNameAt(g.pos(expr), typecheck.ClosureName(ir.CurFunc)) ir.MarkFunc(fn.Nname) - fn.Nname.SetType(g.typ(typ)) + typ := g.typ(typ2) fn.Nname.Func = fn fn.Nname.Defn = fn + // Set Ntype for now to be compatible with later parts of compile, remove later. + fn.Nname.Ntype = ir.TypeNode(typ) + typed(typ, fn.Nname) + fn.SetTypecheck(1) fn.OClosure = ir.NewClosureExpr(g.pos(expr), fn) - fn.OClosure.SetType(fn.Nname.Type()) - fn.OClosure.SetTypecheck(1) + typed(typ, fn.OClosure) g.funcBody(fn, nil, expr.Type, expr.Body) diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 69461a819055cd..fb1bbfedc8b428 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -27,19 +27,37 @@ func (g *irgen) stencil() { // functions calling other generic functions. for i := 0; i < len(g.target.Decls); i++ { decl := g.target.Decls[i] - if decl.Op() != ir.ODCLFUNC || decl.Type().NumTParams() > 0 { - // Skip any non-function declarations and skip generic functions + + // Look for function instantiations in bodies of non-generic + // functions or in global assignments (ignore global type and + // constant declarations). + switch decl.Op() { + case ir.ODCLFUNC: + if decl.Type().HasTParam() { + // Skip any generic functions + continue + } + + case ir.OAS: + + case ir.OAS2: + + default: continue } - // For each non-generic function, search for any function calls using - // generic function instantiations. (We don't yet handle generic - // function instantiations that are not immediately called.) - // Then create the needed instantiated function if it hasn't been - // created yet, and change to calling that function directly. - f := decl.(*ir.Func) + // For all non-generic code, search for any function calls using + // generic function instantiations. Then create the needed + // instantiated function if it hasn't been created yet, and change + // to calling that function directly. modified := false - ir.VisitList(f.Body, func(n ir.Node) { + foundFuncInst := false + ir.Visit(decl, func(n ir.Node) { + if n.Op() == ir.OFUNCINST { + // We found a function instantiation that is not + // immediately called. + foundFuncInst = true + } if n.Op() != ir.OCALLFUNC || n.(*ir.CallExpr).X.Op() != ir.OFUNCINST { return } @@ -47,19 +65,7 @@ func (g *irgen) stencil() { // instantiation. call := n.(*ir.CallExpr) inst := call.X.(*ir.InstExpr) - sym := makeInstName(inst) - //fmt.Printf("Found generic func call in %v to %v\n", f, s) - st := g.target.Stencils[sym] - if st == nil { - // If instantiation doesn't exist yet, create it and add - // to the list of decls. - st = genericSubst(sym, inst) - g.target.Stencils[sym] = st - g.target.Decls = append(g.target.Decls, st) - if base.Flag.W > 1 { - ir.Dump(fmt.Sprintf("\nstenciled %v", st), st) - } - } + st := g.getInstantiation(inst) // Replace the OFUNCINST with a direct reference to the // new stenciled function call.X = st.Nname @@ -76,6 +82,26 @@ func (g *irgen) stencil() { } modified = true }) + + // If we found an OFUNCINST without a corresponding call in the + // above decl, then traverse the nodes of decl again (with + // EditChildren rather than Visit), where we actually change the + // OFUNCINST node to an ONAME for the instantiated function. + // EditChildren is more expensive than Visit, so we only do this + // in the infrequent case of an OFUNCINSt without a corresponding + // call. + if foundFuncInst { + var edit func(ir.Node) ir.Node + edit = func(x ir.Node) ir.Node { + if x.Op() == ir.OFUNCINST { + st := g.getInstantiation(x.(*ir.InstExpr)) + return st.Nname + } + ir.EditChildren(x, edit) + return x + } + edit(decl) + } if base.Flag.W > 1 && modified { ir.Dump(fmt.Sprintf("\nmodified %v", decl), decl) } @@ -83,18 +109,39 @@ func (g *irgen) stencil() { } -// makeInstName makes the unique name for a stenciled generic function, based on -// the name of the function and the types of the type params. -func makeInstName(inst *ir.InstExpr) *types.Sym { - b := bytes.NewBufferString("#") +// getInstantiation gets the instantiated function corresponding to inst. If the +// instantiated function is not already cached, then it calls genericStub to +// create the new instantiation. +func (g *irgen) getInstantiation(inst *ir.InstExpr) *ir.Func { + var sym *types.Sym if meth, ok := inst.X.(*ir.SelectorExpr); ok { // Write the name of the generic method, including receiver type - b.WriteString(meth.Selection.Nname.Sym().Name) + sym = makeInstName(meth.Selection.Nname.Sym(), inst.Targs) } else { - b.WriteString(inst.X.(*ir.Name).Name().Sym().Name) + sym = makeInstName(inst.X.(*ir.Name).Name().Sym(), inst.Targs) } + //fmt.Printf("Found generic func call in %v to %v\n", f, s) + st := g.target.Stencils[sym] + if st == nil { + // If instantiation doesn't exist yet, create it and add + // to the list of decls. + st = g.genericSubst(sym, inst) + g.target.Stencils[sym] = st + g.target.Decls = append(g.target.Decls, st) + if base.Flag.W > 1 { + ir.Dump(fmt.Sprintf("\nstenciled %v", st), st) + } + } + return st +} + +// makeInstName makes the unique name for a stenciled generic function, based on +// the name of the function and the targs. +func makeInstName(fnsym *types.Sym, targs []ir.Node) *types.Sym { + b := bytes.NewBufferString("#") + b.WriteString(fnsym.Name) b.WriteString("[") - for i, targ := range inst.Targs { + for i, targ := range targs { if i > 0 { b.WriteString(",") } @@ -107,6 +154,7 @@ func makeInstName(inst *ir.InstExpr) *types.Sym { // Struct containing info needed for doing the substitution as we create the // instantiation of a generic function with specified type arguments. type subster struct { + g *irgen newf *ir.Func // Func node for the new stenciled function tparams []*types.Field targs []ir.Node @@ -121,7 +169,7 @@ type subster struct { // inst. For a method with a generic receiver, it returns an instantiated function // type where the receiver becomes the first parameter. Otherwise the instantiated // method would still need to be transformed by later compiler phases. -func genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func { +func (g *irgen) genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func { var nameNode *ir.Name var tparams []*types.Field if selExpr, ok := inst.X.(*ir.SelectorExpr); ok { @@ -148,6 +196,7 @@ func genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func { name.Def = newf.Nname subst := &subster{ + g: g, newf: newf, tparams: tparams, targs: inst.Targs, @@ -198,6 +247,9 @@ func (subst *subster) node(n ir.Node) ir.Node { return v } m := ir.NewNameAt(name.Pos(), name.Sym()) + if name.IsClosureVar() { + m.SetIsClosureVar(true) + } t := x.Type() newt := subst.typ(t) m.SetType(newt) @@ -219,10 +271,12 @@ func (subst *subster) node(n ir.Node) ir.Node { // t can be nil only if this is a call that has no // return values, so allow that and otherwise give // an error. - if _, isCallExpr := m.(*ir.CallExpr); !isCallExpr { + _, isCallExpr := m.(*ir.CallExpr) + _, isStructKeyExpr := m.(*ir.StructKeyExpr) + if !isCallExpr && !isStructKeyExpr { base.Fatalf(fmt.Sprintf("Nil type for %v", x)) } - } else { + } else if x.Op() != ir.OCLOSURE { m.SetType(subst.typ(x.Type())) } } @@ -270,14 +324,27 @@ func (subst *subster) node(n ir.Node) ir.Node { if oldfn.ClosureCalled() { newfn.SetClosureCalled(true) } + newfn.SetIsHiddenClosure(true) m.(*ir.ClosureExpr).Func = newfn - newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), oldfn.Nname.Sym()) - newfn.Nname.SetType(oldfn.Nname.Type()) - newfn.Nname.Ntype = subst.node(oldfn.Nname.Ntype).(ir.Ntype) + newsym := makeInstName(oldfn.Nname.Sym(), subst.targs) + newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), newsym) + newfn.Nname.Func = newfn + newfn.Nname.Defn = newfn + ir.MarkFunc(newfn.Nname) + newfn.OClosure = m.(*ir.ClosureExpr) + + saveNewf := subst.newf + subst.newf = newfn + newfn.Dcl = subst.namelist(oldfn.Dcl) + newfn.ClosureVars = subst.namelist(oldfn.ClosureVars) newfn.Body = subst.list(oldfn.Body) - // Make shallow copy of the Dcl and ClosureVar slices - newfn.Dcl = append([]*ir.Name(nil), oldfn.Dcl...) - newfn.ClosureVars = append([]*ir.Name(nil), oldfn.ClosureVars...) + subst.newf = saveNewf + + // Set Ntype for now to be compatible with later parts of compiler + newfn.Nname.Ntype = subst.node(oldfn.Nname.Ntype).(ir.Ntype) + typed(subst.typ(oldfn.Nname.Type()), newfn.Nname) + newfn.SetTypecheck(1) + subst.g.target.Decls = append(subst.g.target.Decls, newfn) } return m } @@ -285,6 +352,20 @@ func (subst *subster) node(n ir.Node) ir.Node { return edit(n) } +func (subst *subster) namelist(l []*ir.Name) []*ir.Name { + s := make([]*ir.Name, len(l)) + for i, n := range l { + s[i] = subst.node(n).(*ir.Name) + if n.Defn != nil { + s[i].Defn = subst.node(n.Defn) + } + if n.Outer != nil { + s[i].Outer = subst.node(n.Outer).(*ir.Name) + } + } + return s +} + func (subst *subster) list(l []ir.Node) []ir.Node { s := make([]ir.Node, len(l)) for i, n := range l { @@ -293,7 +374,9 @@ func (subst *subster) list(l []ir.Node) []ir.Node { return s } -// tstruct substitutes type params in a structure type +// tstruct substitutes type params in types of the fields of a structure type. For +// each field, if Nname is set, tstruct also translates the Nname using subst.vars, if +// Nname is in subst.vars. func (subst *subster) tstruct(t *types.Type) *types.Type { if t.NumFields() == 0 { return t @@ -301,7 +384,7 @@ func (subst *subster) tstruct(t *types.Type) *types.Type { var newfields []*types.Field for i, f := range t.Fields().Slice() { t2 := subst.typ(f.Type) - if t2 != f.Type && newfields == nil { + if (t2 != f.Type || f.Nname != nil) && newfields == nil { newfields = make([]*types.Field, t.NumFields()) for j := 0; j < i; j++ { newfields[j] = t.Field(j) @@ -309,6 +392,12 @@ func (subst *subster) tstruct(t *types.Type) *types.Type { } if newfields != nil { newfields[i] = types.NewField(f.Pos, f.Sym, t2) + if f.Nname != nil { + // f.Nname may not be in subst.vars[] if this is + // a function name or a function instantiation type + // that we are translating + newfields[i].Nname = subst.vars[f.Nname.(*ir.Name)] + } } } if newfields != nil { @@ -319,14 +408,14 @@ func (subst *subster) tstruct(t *types.Type) *types.Type { } // instTypeName creates a name for an instantiated type, based on the type args -func instTypeName(name string, targs []ir.Node) string { +func instTypeName(name string, targs []*types.Type) string { b := bytes.NewBufferString(name) b.WriteByte('[') for i, targ := range targs { if i > 0 { b.WriteByte(',') } - b.WriteString(targ.Type().String()) + b.WriteString(targ.String()) } b.WriteByte(']') return b.String() @@ -415,10 +504,17 @@ func (subst *subster) typ(t *types.Type) *types.Type { // Since we've substituted types, we also need to change // the defined name of the type, by removing the old types // (in brackets) from the name, and adding the new types. + + // Translate the type params for this type according to + // the tparam/targs mapping of the function. + neededTargs := make([]*types.Type, len(t.RParams)) + for i, rparam := range t.RParams { + neededTargs[i] = subst.typ(rparam) + } oldname := t.Sym().Name i := strings.Index(oldname, "[") oldname = oldname[:i] - sym := t.Sym().Pkg.Lookup(instTypeName(oldname, subst.targs)) + sym := t.Sym().Pkg.Lookup(instTypeName(oldname, neededTargs)) if sym.Def != nil { // We've already created this instantiated defined type. return sym.Def.Type() diff --git a/test/typeparam/combine.go b/test/typeparam/combine.go new file mode 100644 index 00000000000000..d4a2988a7b0e46 --- /dev/null +++ b/test/typeparam/combine.go @@ -0,0 +1,65 @@ +// run -gcflags=-G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" +) + +type _Gen[A any] func() (A, bool) + +func combine[T1, T2, T any](g1 _Gen[T1], g2 _Gen[T2], join func(T1, T2) T) _Gen[T] { + return func() (T, bool) { + var t T + t1, ok := g1() + if !ok { + return t, false + } + t2, ok := g2() + if !ok { + return t, false + } + return join(t1, t2), true + } +} + +type _Pair[A, B any] struct { + A A + B B +} + +func _NewPair[A, B any](a A, b B) _Pair[A, B] { + return _Pair[A, B]{a, b} +} + +func _Combine2[A, B any](ga _Gen[A], gb _Gen[B]) _Gen[_Pair[A, B]] { + return combine(ga, gb, _NewPair[A, B]) +} + +func main() { + var g1 _Gen[int] = func() (int, bool) { return 3, true } + var g2 _Gen[string] = func() (string, bool) { return "x", false } + var g3 _Gen[string] = func() (string, bool) { return "y", true } + + gc := combine(g1, g2, _NewPair[int, string]) + if got, ok := gc(); ok { + panic(fmt.Sprintf("got %v, %v, wanted -/false", got, ok)) + } + gc2 := _Combine2(g1, g2) + if got, ok := gc2(); ok { + panic(fmt.Sprintf("got %v, %v, wanted -/false", got, ok)) + } + + gc3 := combine(g1, g3, _NewPair[int, string]) + if got, ok := gc3(); !ok || got.A != 3 || got.B != "y" { + panic(fmt.Sprintf("got %v, %v, wanted {3, y}, true", got, ok)) + } + gc4 := _Combine2(g1, g3) + if got, ok := gc4(); !ok || got.A != 3 || got.B != "y" { + panic (fmt.Sprintf("got %v, %v, wanted {3, y}, true", got, ok)) + } +} diff --git a/test/typeparam/settable.go b/test/typeparam/settable.go index 3bd141f7842754..7532953a773f00 100644 --- a/test/typeparam/settable.go +++ b/test/typeparam/settable.go @@ -11,7 +11,24 @@ import ( "strconv" ) -func fromStrings3[T any](s []string, set func(*T, string)) []T { +type Setter[B any] interface { + Set(string) + type *B +} + +func fromStrings1[T any, PT Setter[T]](s []string) []T { + result := make([]T, len(s)) + for i, v := range s { + // The type of &result[i] is *T which is in the type list + // of Setter, so we can convert it to PT. + p := PT(&result[i]) + // PT has a Set method. + p.Set(v) + } + return result +} + +func fromStrings2[T any](s []string, set func(*T, string)) []T { results := make([]T, len(s)) for i, v := range s { set(&results[i], v) @@ -30,8 +47,12 @@ func (p *Settable) Set(s string) { } func main() { - s := fromStrings3([]string{"1"}, - func(p *Settable, s string) { p.Set(s) }) + s := fromStrings1[Settable, *Settable]([]string{"1"}) + if len(s) != 1 || s[0] != 1 { + panic(fmt.Sprintf("got %v, want %v", s, []int{1})) + } + + s = fromStrings2([]string{"1"}, func(p *Settable, s string) { p.Set(s) }) if len(s) != 1 || s[0] != 1 { panic(fmt.Sprintf("got %v, want %v", s, []int{1})) }