From 6a6fd991e914cbc2859fd968949132085277a6d0 Mon Sep 17 00:00:00 2001 From: Tim King Date: Wed, 31 Jul 2024 13:04:53 -0700 Subject: [PATCH] go/ssa: substitute type parameterized aliases Adds support to substitute type parameterized aliases in generic functions. Change-Id: I4fb2e5f5fd9b626781efdc4db808c52cb22ba241 Reviewed-on: https://go-review.googlesource.com/c/tools/+/602195 Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- go/ssa/builder_generic_test.go | 26 +++--- go/ssa/builder_go122_test.go | 8 +- go/ssa/builder_test.go | 141 ++++++++++++++++++++++++++++++ go/ssa/subst.go | 83 ++++++++++++++++-- internal/aliases/aliases_go121.go | 13 +-- internal/aliases/aliases_go122.go | 28 ++++++ 6 files changed, 273 insertions(+), 26 deletions(-) diff --git a/go/ssa/builder_generic_test.go b/go/ssa/builder_generic_test.go index 33531dabffc..55dc79fe464 100644 --- a/go/ssa/builder_generic_test.go +++ b/go/ssa/builder_generic_test.go @@ -550,7 +550,13 @@ func TestGenericBodies(t *testing.T) { } // Collect calls to the builtin print function. - probes := callsTo(p, "print") + fns := make(map[*ssa.Function]bool) + for _, mem := range p.Members { + if fn, ok := mem.(*ssa.Function); ok { + fns[fn] = true + } + } + probes := callsTo(fns, "print") expectations := matchNotes(prog.Fset, notes, probes) for call := range probes { @@ -576,17 +582,15 @@ func TestGenericBodies(t *testing.T) { // callsTo finds all calls to an SSA value named fname, // and returns a map from each call site to its enclosing function. -func callsTo(p *ssa.Package, fname string) map[*ssa.CallCommon]*ssa.Function { +func callsTo(fns map[*ssa.Function]bool, fname string) map[*ssa.CallCommon]*ssa.Function { callsites := make(map[*ssa.CallCommon]*ssa.Function) - for _, mem := range p.Members { - if fn, ok := mem.(*ssa.Function); ok { - for _, bb := range fn.Blocks { - for _, i := range bb.Instrs { - if i, ok := i.(ssa.CallInstruction); ok { - call := i.Common() - if call.Value.Name() == fname { - callsites[call] = fn - } + for fn := range fns { + for _, bb := range fn.Blocks { + for _, i := range bb.Instrs { + if i, ok := i.(ssa.CallInstruction); ok { + call := i.Common() + if call.Value.Name() == fname { + callsites[call] = fn } } } diff --git a/go/ssa/builder_go122_test.go b/go/ssa/builder_go122_test.go index d98431296a7..bde5bae9292 100644 --- a/go/ssa/builder_go122_test.go +++ b/go/ssa/builder_go122_test.go @@ -168,7 +168,13 @@ func TestRangeOverInt(t *testing.T) { } // Collect calls to the built-in print function. - probes := callsTo(p, "print") + fns := make(map[*ssa.Function]bool) + for _, mem := range p.Members { + if fn, ok := mem.(*ssa.Function); ok { + fns[fn] = true + } + } + probes := callsTo(fns, "print") expectations := matchNotes(fset, notes, probes) for call := range probes { diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go index ed1d84feeb9..f6fae50bb67 100644 --- a/go/ssa/builder_test.go +++ b/go/ssa/builder_test.go @@ -14,6 +14,7 @@ import ( "go/token" "go/types" "os" + "os/exec" "path/filepath" "reflect" "sort" @@ -1260,3 +1261,143 @@ func TestIssue67079(t *testing.T) { g.Wait() // ignore error } + +func TestGenericAliases(t *testing.T) { + testenv.NeedsGo1Point(t, 23) + + if os.Getenv("GENERICALIASTEST_CHILD") == "1" { + testGenericAliases(t) + return + } + + testenv.NeedsExec(t) + testenv.NeedsTool(t, "go") + + cmd := exec.Command(os.Args[0], "-test.run=TestGenericAliases") + cmd.Env = append(os.Environ(), + "GENERICALIASTEST_CHILD=1", + "GODEBUG=gotypesalias=1", + "GOEXPERIMENT=aliastypeparams", + ) + out, err := cmd.CombinedOutput() + if len(out) > 0 { + t.Logf("out=<<%s>>", out) + } + var exitcode int + if err, ok := err.(*exec.ExitError); ok { + exitcode = err.ExitCode() + } + const want = 0 + if exitcode != want { + t.Errorf("exited %d, want %d", exitcode, want) + } +} + +func testGenericAliases(t *testing.T) { + t.Setenv("GOEXPERIMENT", "aliastypeparams=1") + + const source = ` +package P + +type A = uint8 +type B[T any] = [4]T + +var F = f[string] + +func f[S any]() { + // Two copies of f are made: p.f[S] and p.f[string] + + var v A // application of A that is declared outside of f without no type arguments + print("p.f", "String", "p.A", v) + print("p.f", "==", v, uint8(0)) + print("p.f[string]", "String", "p.A", v) + print("p.f[string]", "==", v, uint8(0)) + + + var u B[S] // application of B that is declared outside declared outside of f with type arguments + print("p.f", "String", "p.B[S]", u) + print("p.f", "==", u, [4]S{}) + print("p.f[string]", "String", "p.B[string]", u) + print("p.f[string]", "==", u, [4]string{}) + + type C[T any] = struct{ s S; ap *B[T]} // declaration within f with type params + var w C[int] // application of C with type arguments + print("p.f", "String", "p.C[int]", w) + print("p.f", "==", w, struct{ s S; ap *[4]int}{}) + print("p.f[string]", "String", "p.C[int]", w) + print("p.f[string]", "==", w, struct{ s string; ap *[4]int}{}) +} +` + + conf := loader.Config{Fset: token.NewFileSet()} + f, err := parser.ParseFile(conf.Fset, "p.go", source, 0) + if err != nil { + t.Fatal(err) + } + conf.CreateFromFiles("p", f) + iprog, err := conf.Load() + if err != nil { + t.Fatal(err) + } + + // Create and build SSA program. + prog := ssautil.CreateProgram(iprog, ssa.InstantiateGenerics) + prog.Build() + + probes := callsTo(ssautil.AllFunctions(prog), "print") + if got, want := len(probes), 3*4*2; got != want { + t.Errorf("Found %v probes, expected %v", got, want) + } + + const debug = false // enable to debug skips + skipped := 0 + for probe, fn := range probes { + // Each probe is of the form: + // print("within", "test", head, tail) + // The probe only matches within a function whose fn.String() is within. + // This allows for different instantiations of fn to match different probes. + // On a match, it applies the test named "test" to head::tail. + if len(probe.Args) < 3 { + t.Fatalf("probe %v did not have enough arguments", probe) + } + within, test, head, tail := constString(probe.Args[0]), probe.Args[1], probe.Args[2], probe.Args[3:] + if within != fn.String() { + skipped++ + if debug { + t.Logf("Skipping %q within %q", within, fn.String()) + } + continue // does not match function + } + + switch test := constString(test); test { + case "==": // All of the values are types.Identical. + for _, v := range tail { + if !types.Identical(head.Type(), v.Type()) { + t.Errorf("Expected %v and %v to have identical types", head, v) + } + } + case "String": // head is a string constant that all values in tail must match Type().String() + want := constString(head) + for _, v := range tail { + if got := v.Type().String(); got != want { + t.Errorf("%s: %v had the Type().String()=%q. expected %q", within, v, got, want) + } + } + default: + t.Errorf("%q is not a test subcommand", test) + } + } + if want := 3 * 4; skipped != want { + t.Errorf("Skipped %d probes, expected to skip %d", skipped, want) + } +} + +// constString returns the value of a string constant +// or "" if the value is not a string constant. +func constString(v ssa.Value) string { + if c, ok := v.(*ssa.Const); ok { + str := c.Value.String() + return strings.Trim(str, `"`) + } + return "" +} diff --git a/go/ssa/subst.go b/go/ssa/subst.go index 75d887d7e52..4dcb871572d 100644 --- a/go/ssa/subst.go +++ b/go/ssa/subst.go @@ -318,15 +318,80 @@ func (subst *subster) interface_(iface *types.Interface) *types.Interface { } func (subst *subster) alias(t *aliases.Alias) types.Type { - // TODO(go.dev/issues/46477): support TypeParameters once these are available from go/types. - u := aliases.Unalias(t) - if s := subst.typ(u); s != u { - // If there is any change, do not create a new alias. - return s + // See subster.named. This follows the same strategy. + tparams := aliases.TypeParams(t) + targs := aliases.TypeArgs(t) + tname := t.Obj() + torigin := aliases.Origin(t) + + if !declaredWithin(tname, subst.origin) { + // t is declared outside of the function origin. So t is a package level type alias. + if targs.Len() == 0 { + // No type arguments so no instantiation needed. + return t + } + + // Instantiate with the substituted type arguments. + newTArgs := subst.typelist(targs) + return subst.instantiate(torigin, newTArgs) } - // If there is no change, t did not reach any type parameter. - // Keep the Alias. - return t + + if targs.Len() == 0 { + // t is declared within the function origin and has no type arguments. + // + // Example: This corresponds to A or B in F, but not A[int]: + // + // func F[T any]() { + // type A[S any] = struct{t T, s S} + // type B = T + // var x A[int] + // ... + // } + // + // This is somewhat different than *Named as *Alias cannot be created recursively. + + // Copy and substitute type params. + var newTParams []*types.TypeParam + for i := 0; i < tparams.Len(); i++ { + cur := tparams.At(i) + cobj := cur.Obj() + cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil) + ntp := types.NewTypeParam(cname, nil) + subst.cache[cur] = ntp // See the comment "Note: Subtle" in subster.named. + newTParams = append(newTParams, ntp) + } + + // Substitute rhs. + rhs := subst.typ(aliases.Rhs(t)) + + // Create the fresh alias. + obj := aliases.NewAlias(true, tname.Pos(), tname.Pkg(), tname.Name(), rhs) + fresh := obj.Type() + if fresh, ok := fresh.(*aliases.Alias); ok { + // TODO: assume ok when aliases are always materialized (go1.27). + aliases.SetTypeParams(fresh, newTParams) + } + + // Substitute into all of the constraints after they are created. + for i, ntp := range newTParams { + bound := tparams.At(i).Constraint() + ntp.SetConstraint(subst.typ(bound)) + } + return fresh + } + + // t is declared within the function origin and has type arguments. + // + // Example: This corresponds to A[int] in F. Cases A and B are handled above. + // func F[T any]() { + // type A[S any] = struct{t T, s S} + // type B = T + // var x A[int] + // ... + // } + subOrigin := subst.typ(torigin) + subTArgs := subst.typelist(targs) + return subst.instantiate(subOrigin, subTArgs) } func (subst *subster) named(t *types.Named) types.Type { @@ -456,7 +521,7 @@ func (subst *subster) named(t *types.Named) types.Type { func (subst *subster) instantiate(orig types.Type, targs []types.Type) types.Type { i, err := types.Instantiate(subst.ctxt, orig, targs, false) - assert(err == nil, "failed to Instantiate Named type") + assert(err == nil, "failed to Instantiate named (Named or Alias) type") if c, _ := subst.uniqueness.At(i).(types.Type); c != nil { return c.(types.Type) } diff --git a/internal/aliases/aliases_go121.go b/internal/aliases/aliases_go121.go index 63391e584b6..6652f7db0fb 100644 --- a/internal/aliases/aliases_go121.go +++ b/internal/aliases/aliases_go121.go @@ -15,11 +15,14 @@ import ( // It will never be created by go/types. type Alias struct{} -func (*Alias) String() string { panic("unreachable") } -func (*Alias) Underlying() types.Type { panic("unreachable") } -func (*Alias) Obj() *types.TypeName { panic("unreachable") } -func Rhs(alias *Alias) types.Type { panic("unreachable") } -func TypeParams(alias *Alias) *types.TypeParamList { panic("unreachable") } +func (*Alias) String() string { panic("unreachable") } +func (*Alias) Underlying() types.Type { panic("unreachable") } +func (*Alias) Obj() *types.TypeName { panic("unreachable") } +func Rhs(alias *Alias) types.Type { panic("unreachable") } +func TypeParams(alias *Alias) *types.TypeParamList { panic("unreachable") } +func SetTypeParams(alias *Alias, tparams []*types.TypeParam) { panic("unreachable") } +func TypeArgs(alias *Alias) *types.TypeList { panic("unreachable") } +func Origin(alias *Alias) *Alias { panic("unreachable") } // Unalias returns the type t for go <=1.21. func Unalias(t types.Type) types.Type { return t } diff --git a/internal/aliases/aliases_go122.go b/internal/aliases/aliases_go122.go index 96fcd166702..3ef1afeb403 100644 --- a/internal/aliases/aliases_go122.go +++ b/internal/aliases/aliases_go122.go @@ -36,6 +36,34 @@ func TypeParams(alias *Alias) *types.TypeParamList { return nil } +// SetTypeParams sets the type parameters of the alias type. +func SetTypeParams(alias *Alias, tparams []*types.TypeParam) { + if alias, ok := any(alias).(interface { + SetTypeParams(tparams []*types.TypeParam) + }); ok { + alias.SetTypeParams(tparams) // go1.23+ + } else if len(tparams) > 0 { + panic("cannot set type parameters of an Alias type in go1.22") + } +} + +// TypeArgs returns the type arguments used to instantiate the Alias type. +func TypeArgs(alias *Alias) *types.TypeList { + if alias, ok := any(alias).(interface{ TypeArgs() *types.TypeList }); ok { + return alias.TypeArgs() // go1.23+ + } + return nil // empty (go1.22) +} + +// Origin returns the generic Alias type of which alias is an instance. +// If alias is not an instance of a generic alias, Origin returns alias. +func Origin(alias *Alias) *Alias { + if alias, ok := any(alias).(interface{ Origin() *types.Alias }); ok { + return alias.Origin() // go1.23+ + } + return alias // not an instance of a generic alias (go1.22) +} + // Unalias is a wrapper of types.Unalias. func Unalias(t types.Type) types.Type { return types.Unalias(t) }