Skip to content

Commit

Permalink
interp: improve handling of generic types
Browse files Browse the repository at this point in the history
When generating a new type, the parameter type was not correctly duplicated in the new AST. This is fixed by making copyNode recursive if needed. The out of order processing of generic types has also been fixed.

Fixes #1488
  • Loading branch information
mvertes authored Feb 8, 2023
1 parent 0e3ea57 commit f3dbce9
Show file tree
Hide file tree
Showing 12 changed files with 423 additions and 113 deletions.
33 changes: 33 additions & 0 deletions _test/gen11.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package main

import (
"encoding/json"
"fmt"
"net/netip"
)

type Slice[T any] struct {
x []T
}

type IPPrefixSlice struct {
x Slice[netip.Prefix]
}

func (v Slice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.x) }

// MarshalJSON implements json.Marshaler.
func (v IPPrefixSlice) MarshalJSON() ([]byte, error) {
return v.x.MarshalJSON()
}

func main() {
t := IPPrefixSlice{}
fmt.Println(t)
b, e := t.MarshalJSON()
fmt.Println(string(b), e)
}

// Output:
// {{[]}}
// null <nil>
31 changes: 31 additions & 0 deletions _test/gen12.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package main

import (
"fmt"
)

func MapOf[K comparable, V any](m map[K]V) Map[K, V] {
return Map[K, V]{m}
}

type Map[K comparable, V any] struct {
ж map[K]V
}

func (v MapView) Int() Map[string, int] { return MapOf(v.ж.Int) }

type VMap struct {
Int map[string]int
}

type MapView struct {
ж *VMap
}

func main() {
mv := MapView{&VMap{}}
fmt.Println(mv.ж)
}

// Output:
// &{map[]}
18 changes: 18 additions & 0 deletions _test/gen13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package main

type Map[K comparable, V any] struct {
ж map[K]V
}

func (m Map[K, V]) Has(k K) bool {
_, ok := m.ж[k]
return ok
}

func main() {
m := Map[string, float64]{}
println(m.Has("test"))
}

// Output:
// false
13 changes: 9 additions & 4 deletions _test/issue-1460.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"errors"
"net/netip"
"reflect"
)

Expand All @@ -17,6 +18,10 @@ func unmarshalJSON[T any](b []byte, x *[]T) error {
return json.Unmarshal(b, x)
}

func SliceOfViews[T ViewCloner[T, V], V StructView[T]](x []T) SliceView[T, V] {
return SliceView[T, V]{x}
}

type StructView[T any] interface {
Valid() bool
AsStruct() T
Expand All @@ -31,10 +36,6 @@ type ViewCloner[T any, V StructView[T]] interface {
Clone() T
}

func SliceOfViews[T ViewCloner[T, V], V StructView[T]](x []T) SliceView[T, V] {
return SliceView[T, V]{x}
}

func (v SliceView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }

func (v *SliceView[T, V]) UnmarshalJSON(b []byte) error { return unmarshalJSON(b, &v.ж) }
Expand All @@ -51,6 +52,10 @@ func SliceOf[T any](x []T) Slice[T] {
return Slice[T]{x}
}

type IPPrefixSlice struct {
ж Slice[netip.Prefix]
}

type viewStruct struct {
Int int
Strings Slice[string]
Expand Down
23 changes: 23 additions & 0 deletions _test/issue-1488.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package main

import "fmt"

type vector interface {
[]int | [3]int
}

func sum[V vector](v V) (out int) {
for i := 0; i < len(v); i++ {
out += v[i]
}
return
}

func main() {
va := [3]int{1, 2, 3}
vs := []int{1, 2, 3}
fmt.Println(sum[[3]int](va), sum[[]int](vs))
}

// Output:
// 6 6
14 changes: 14 additions & 0 deletions _test/p6.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package main

import (
"fmt"

"github.com/traefik/yaegi/_test/p6"
)

func main() {
t := p6.IPPrefixSlice{}
fmt.Println(t)
b, e := t.MarshalJSON()
fmt.Println(string(b), e)
}
21 changes: 21 additions & 0 deletions _test/p6/p6.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package p6

import (
"encoding/json"
"net/netip"
)

type Slice[T any] struct {
x []T
}

type IPPrefixSlice struct {
x Slice[netip.Prefix]
}

func (v Slice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.x) }

// MarshalJSON implements json.Marshaler.
func (v IPPrefixSlice) MarshalJSON() ([]byte, error) {
return v.x.MarshalJSON()
}
136 changes: 109 additions & 27 deletions interp/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,60 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
}
}
if n.typ == nil {
err = n.cfgErrorf("undefined type")
return false
// A nil type indicates either an error or a generic type.
// A child indexExpr or indexListExpr is used for type parameters,
// it indicates an instanciated generic.
if n.child[0].kind != indexExpr && n.child[0].kind != indexListExpr {
err = n.cfgErrorf("undefined type")
return false
}
t0, err1 := nodeType(interp, sc, n.child[0].child[0])
if err1 != nil {
return false
}
if t0.cat != genericT {
err = n.cfgErrorf("undefined type")
return false
}
// We have a composite literal of generic type, instantiate it.
lt := []*itype{}
for _, n1 := range n.child[0].child[1:] {
t1, err1 := nodeType(interp, sc, n1)
if err1 != nil {
return false
}
lt = append(lt, t1)
}
var g *node
g, _, err = genAST(sc, t0.node.anc, lt)
if err != nil {
return false
}
n.child[0] = g.lastChild()
n.typ, err = nodeType(interp, sc, n.child[0])
if err != nil {
return false
}
// Generate methods if any.
for _, nod := range t0.method {
gm, _, err2 := genAST(nod.scope, nod, lt)
if err2 != nil {
err = err2
return false
}
gm.typ, err = nodeType(interp, nod.scope, gm.child[2])
if err != nil {
return false
}
if _, err = interp.cfg(gm, sc, sc.pkgID, sc.pkgName); err != nil {
return false
}
if err = genRun(gm); err != nil {
return false
}
n.typ.addMethod(gm)
}
n.nleft = 1 // Indictate the type of composite literal.
}
}

Expand Down Expand Up @@ -439,6 +491,19 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
if typ, err = nodeType(interp, sc, recvTypeNode); err != nil {
return false
}
if typ.cat == nilT {
// This may happen when instantiating generic methods.
s2, _, ok := sc.lookup(typ.id())
if !ok {
err = n.cfgErrorf("type not found: %s", typ.id())
break
}
typ = s2.typ
if typ.cat == nilT {
err = n.cfgErrorf("nil type: %s", typ.id())
break
}
}
recvTypeNode.typ = typ
n.child[2].typ.recv = typ
n.typ.recv = typ
Expand Down Expand Up @@ -871,16 +936,18 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
n.typ = t
return
}
g, err := genAST(sc, t.node.anc, []*node{c1})
g, found, err := genAST(sc, t.node.anc, []*itype{c1.typ})
if err != nil {
return
}
if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil {
return
}
// Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
return
if !found {
if _, err = interp.cfg(g, t.node.anc.scope, importPath, pkgName); err != nil {
return
}
// Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
return
}
}
// Replace generic func node by instantiated one.
n.anc.child[childPos(n)] = g
Expand Down Expand Up @@ -1030,17 +1097,23 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
case c0.kind == indexListExpr:
// Instantiate a generic function then call it.
fun := c0.child[0].sym.node
g, err := genAST(sc, fun, c0.child[1:])
if err != nil {
return
lt := []*itype{}
for _, c := range c0.child[1:] {
lt = append(lt, c.typ)
}
_, err = interp.cfg(g, nil, importPath, pkgName)
g, found, err := genAST(sc, fun, lt)
if err != nil {
return
}
err = genRun(g.child[3]) // Generate closures for function body.
if err != nil {
return
if !found {
_, err = interp.cfg(g, fun.scope, importPath, pkgName)
if err != nil {
return
}
err = genRun(g.child[3]) // Generate closures for function body.
if err != nil {
return
}
}
n.child[0] = g
c0 = n.child[0]
Expand Down Expand Up @@ -1212,23 +1285,26 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
if isGeneric(c0.typ) {
fun := c0.typ.node.anc
var g *node
var types []*node
var types []*itype
var found bool

// Infer type parameter from function call arguments.
if types, err = inferTypesFromCall(sc, fun, n.child[1:]); err != nil {
break
}
// Generate an instantiated AST from the generic function one.
if g, err = genAST(sc, fun, types); err != nil {
break
}
// Compile the generated function AST, so it becomes part of the scope.
if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil {
if g, found, err = genAST(sc, fun, types); err != nil {
break
}
// AST compilation part 2: Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
break
if !found {
// Compile the generated function AST, so it becomes part of the scope.
if _, err = interp.cfg(g, fun.scope, importPath, pkgName); err != nil {
break
}
// AST compilation part 2: Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
break
}
}
n.child[0] = g
c0 = n.child[0]
Expand Down Expand Up @@ -1487,6 +1563,10 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string

sym, level, found := sc.lookup(n.ident)
if !found {
if n.typ != nil {
// Node is a generic instance with an already populated type.
break
}
// retry with the filename, in case ident is a package name.
sym, level, found = sc.lookup(filepath.Join(n.ident, baseName))
if !found {
Expand Down Expand Up @@ -1916,7 +1996,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
err = n.cfgErrorf("undefined selector: %s", n.child[1].ident)
}
}
if err == nil && n.findex != -1 {
if err == nil && n.findex != -1 && n.typ.cat != genericT {
n.findex = sc.add(n.typ)
}

Expand Down Expand Up @@ -2375,11 +2455,13 @@ func (n *node) cfgErrorf(format string, a ...interface{}) *cfgError {

func genRun(nod *node) error {
var err error
seen := map[*node]bool{}

nod.Walk(func(n *node) bool {
if err != nil {
if err != nil || seen[n] {
return false
}
seen[n] = true
switch n.kind {
case funcType:
if len(n.anc.child) == 4 {
Expand Down
Loading

0 comments on commit f3dbce9

Please sign in to comment.