Skip to content

Commit

Permalink
[dev.typeparams] cmd/compile: allow generic funcs to call other gener…
Browse files Browse the repository at this point in the history
…ic funcs for stenciling

 - Handle generic function calling itself or another generic function in
   stenciling. This is easy - after it is created, just scan an
   instantiated generic function for function instantiations (that may
   needed to be stenciled), just like non-generic functions. The types
   in the function instantiation will already have been set by the
   stenciling.

 - Handle OTYPE nodes in subster.node() (allows for generic type
   conversions).

 - Eliminated some duplicated work in subster.typ().

 - Added new test case fact.go that tests a generic function calling
   itself, and simple generic type conversions.

 - Cause an error if a generic function is to be exported (which we
   don't handle yet).

 - Fixed some suggested changes in the add.go test case that I missed in
   the last review.

Change-Id: I5d61704254c27962f358d5a3d2e0c62a5099f148
Reviewed-on: https://go-review.googlesource.com/c/go/+/290469
Trust: Dan Scales <danscales@google.com>
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
  • Loading branch information
danscales committed Feb 8, 2021
1 parent dcb5e03 commit 0fbde54
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 14 deletions.
3 changes: 3 additions & 0 deletions src/cmd/compile/internal/noder/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ func (g *irgen) objFinish(name *ir.Name, class ir.Class, typ *types.Type) {
break // methods are exported with their receiver type
}
if types.IsExported(sym.Name) {
if name.Class == ir.PFUNC && name.Type().NumTParams() > 0 {
base.FatalfAt(name.Pos(), "Cannot export a generic function (yet): %v", name)
}
typecheck.Export(name)
}
if base.Flag.AsmHdr != "" && !name.Sym().Asm() {
Expand Down
15 changes: 11 additions & 4 deletions src/cmd/compile/internal/noder/stencil.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ import (
// creates the required stencils for simple generic functions.
func (g *irgen) stencil() {
g.target.Stencils = make(map[*types.Sym]*ir.Func)
for _, decl := range g.target.Decls {
// Don't use range(g.target.Decls) - we also want to process any new instantiated
// functions that are created during this loop, in order to handle generic
// functions calling other generic functions.
for i := 0; i < len(g.target.Decls); i++ {
decl := g.target.Decls[i]
if decl.Op() != ir.ODCLFUNC || decl.Type().NumTParams() > 0 {
// Skip any non-function declarations and skip generic functions
continue
Expand Down Expand Up @@ -142,6 +146,9 @@ func (subst *subster) node(n ir.Node) ir.Node {
var edit func(ir.Node) ir.Node
edit = func(x ir.Node) ir.Node {
switch x.Op() {
case ir.OTYPE:
return ir.TypeNode(subst.typ(x.Type()))

case ir.ONAME:
name := x.(*ir.Name)
if v := subst.vars[name]; v != nil {
Expand Down Expand Up @@ -211,21 +218,21 @@ func (subst *subster) typ(t *types.Type) *types.Type {
case types.TARRAY:
elem := t.Elem()
newelem := subst.typ(elem)
if subst.typ(elem) != elem {
if newelem != elem {
return types.NewArray(newelem, t.NumElem())
}

case types.TPTR:
elem := t.Elem()
newelem := subst.typ(elem)
if subst.typ(elem) != elem {
if newelem != elem {
return types.NewPtr(newelem)
}

case types.TSLICE:
elem := t.Elem()
newelem := subst.typ(elem)
if subst.typ(elem) != elem {
if newelem != elem {
return types.NewSlice(newelem)
}

Expand Down
35 changes: 35 additions & 0 deletions test/typeparam/fact.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// run -gcflags=-G=3

// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
"fmt"
)


func fact[T interface { type float64 }](n T) T {
if n == T(1) {
return T(1)
}
return n * fact(n - T(1))
}

func main() {
got := fact(4.0)
want := 24.0
if got != want {
panic(fmt.Sprintf("Got %f, want %f", got, want))
}

// Re-enable when types2 bug is fixed (can't do T(1) with more than one
// type in the type list).
//got = fact(5)
//want = 120
//if want != got {
// panic(fmt.Sprintf("Want %d, got %d", want, got))
//}
}
20 changes: 10 additions & 10 deletions test/typeparam/add.go → test/typeparam/sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"fmt"
)

func add[T interface{ type int, float64 }](vec []T) T {
func sum[T interface{ type int, float64 }](vec []T) T {
var sum T
for _, elt := range vec {
sum = sum + elt
Expand All @@ -28,23 +28,23 @@ func abs(f float64) float64 {
func main() {
vec1 := []int{3, 4}
vec2 := []float64{5.8, 9.6}
got := sum[int](vec1)
want := vec1[0] + vec1[1]
got := add[int](vec1)
if want != got {
panic(fmt.Sprintf("Want %d, got %d", want, got))
if got != want {
panic(fmt.Sprintf("Got %d, want %d", got, want))
}
got = add(vec1)
got = sum(vec1)
if want != got {
panic(fmt.Sprintf("Want %d, got %d", want, got))
panic(fmt.Sprintf("Got %d, want %d", got, want))
}

fwant := vec2[0] + vec2[1]
fgot := add[float64](vec2)
fgot := sum[float64](vec2)
if abs(fgot - fwant) > 1e-10 {
panic(fmt.Sprintf("Want %f, got %f", fwant, fgot))
panic(fmt.Sprintf("Got %f, want %f", fgot, fwant))
}
fgot = add(vec2)
fgot = sum(vec2)
if abs(fgot - fwant) > 1e-10 {
panic(fmt.Sprintf("Want %f, got %f", fwant, fgot))
panic(fmt.Sprintf("Got %f, want %f", fgot, fwant))
}
}

0 comments on commit 0fbde54

Please sign in to comment.