diff --git a/src/cmd/compile/internal/noder/object.go b/src/cmd/compile/internal/noder/object.go index c740285ca2ea3b..b4e5c022dbfe8d 100644 --- a/src/cmd/compile/internal/noder/object.go +++ b/src/cmd/compile/internal/noder/object.go @@ -155,6 +155,9 @@ func (g *irgen) objFinish(name *ir.Name, class ir.Class, typ *types.Type) { break // methods are exported with their receiver type } if types.IsExported(sym.Name) { + if name.Class == ir.PFUNC && name.Type().NumTParams() > 0 { + base.FatalfAt(name.Pos(), "Cannot export a generic function (yet): %v", name) + } typecheck.Export(name) } if base.Flag.AsmHdr != "" && !name.Sym().Asm() { diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 3c6c7f4a8c15aa..0c4eadcf448549 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -20,7 +20,11 @@ import ( // creates the required stencils for simple generic functions. func (g *irgen) stencil() { g.target.Stencils = make(map[*types.Sym]*ir.Func) - for _, decl := range g.target.Decls { + // Don't use range(g.target.Decls) - we also want to process any new instantiated + // functions that are created during this loop, in order to handle generic + // functions calling other generic functions. + for i := 0; i < len(g.target.Decls); i++ { + decl := g.target.Decls[i] if decl.Op() != ir.ODCLFUNC || decl.Type().NumTParams() > 0 { // Skip any non-function declarations and skip generic functions continue @@ -142,6 +146,9 @@ func (subst *subster) node(n ir.Node) ir.Node { var edit func(ir.Node) ir.Node edit = func(x ir.Node) ir.Node { switch x.Op() { + case ir.OTYPE: + return ir.TypeNode(subst.typ(x.Type())) + case ir.ONAME: name := x.(*ir.Name) if v := subst.vars[name]; v != nil { @@ -211,21 +218,21 @@ func (subst *subster) typ(t *types.Type) *types.Type { case types.TARRAY: elem := t.Elem() newelem := subst.typ(elem) - if subst.typ(elem) != elem { + if newelem != elem { return types.NewArray(newelem, t.NumElem()) } case types.TPTR: elem := t.Elem() newelem := subst.typ(elem) - if subst.typ(elem) != elem { + if newelem != elem { return types.NewPtr(newelem) } case types.TSLICE: elem := t.Elem() newelem := subst.typ(elem) - if subst.typ(elem) != elem { + if newelem != elem { return types.NewSlice(newelem) } diff --git a/test/typeparam/fact.go b/test/typeparam/fact.go new file mode 100644 index 00000000000000..e5e0ad4ff3787c --- /dev/null +++ b/test/typeparam/fact.go @@ -0,0 +1,35 @@ +// run -gcflags=-G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" +) + + +func fact[T interface { type float64 }](n T) T { + if n == T(1) { + return T(1) + } + return n * fact(n - T(1)) +} + +func main() { + got := fact(4.0) + want := 24.0 + if got != want { + panic(fmt.Sprintf("Got %f, want %f", got, want)) + } + + // Re-enable when types2 bug is fixed (can't do T(1) with more than one + // type in the type list). + //got = fact(5) + //want = 120 + //if want != got { + // panic(fmt.Sprintf("Want %d, got %d", want, got)) + //} +} diff --git a/test/typeparam/add.go b/test/typeparam/sum.go similarity index 61% rename from test/typeparam/add.go rename to test/typeparam/sum.go index b0cf76d3ee313b..72511c2fe5b4ad 100644 --- a/test/typeparam/add.go +++ b/test/typeparam/sum.go @@ -10,7 +10,7 @@ import ( "fmt" ) -func add[T interface{ type int, float64 }](vec []T) T { +func sum[T interface{ type int, float64 }](vec []T) T { var sum T for _, elt := range vec { sum = sum + elt @@ -28,23 +28,23 @@ func abs(f float64) float64 { func main() { vec1 := []int{3, 4} vec2 := []float64{5.8, 9.6} + got := sum[int](vec1) want := vec1[0] + vec1[1] - got := add[int](vec1) - if want != got { - panic(fmt.Sprintf("Want %d, got %d", want, got)) + if got != want { + panic(fmt.Sprintf("Got %d, want %d", got, want)) } - got = add(vec1) + got = sum(vec1) if want != got { - panic(fmt.Sprintf("Want %d, got %d", want, got)) + panic(fmt.Sprintf("Got %d, want %d", got, want)) } fwant := vec2[0] + vec2[1] - fgot := add[float64](vec2) + fgot := sum[float64](vec2) if abs(fgot - fwant) > 1e-10 { - panic(fmt.Sprintf("Want %f, got %f", fwant, fgot)) + panic(fmt.Sprintf("Got %f, want %f", fgot, fwant)) } - fgot = add(vec2) + fgot = sum(vec2) if abs(fgot - fwant) > 1e-10 { - panic(fmt.Sprintf("Want %f, got %f", fwant, fgot)) + panic(fmt.Sprintf("Got %f, want %f", fgot, fwant)) } }