diff --git a/_test/gen11.go b/_test/gen11.go new file mode 100644 index 000000000..82100f0d0 --- /dev/null +++ b/_test/gen11.go @@ -0,0 +1,33 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/netip" +) + +type Slice[T any] struct { + x []T +} + +type IPPrefixSlice struct { + x Slice[netip.Prefix] +} + +func (v Slice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.x) } + +// MarshalJSON implements json.Marshaler. +func (v IPPrefixSlice) MarshalJSON() ([]byte, error) { + return v.x.MarshalJSON() +} + +func main() { + t := IPPrefixSlice{} + fmt.Println(t) + b, e := t.MarshalJSON() + fmt.Println(string(b), e) +} + +// Output: +// {{[]}} +// null diff --git a/_test/gen12.go b/_test/gen12.go new file mode 100644 index 000000000..d93298e2e --- /dev/null +++ b/_test/gen12.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" +) + +func MapOf[K comparable, V any](m map[K]V) Map[K, V] { + return Map[K, V]{m} +} + +type Map[K comparable, V any] struct { + ж map[K]V +} + +func (v MapView) Int() Map[string, int] { return MapOf(v.ж.Int) } + +type VMap struct { + Int map[string]int +} + +type MapView struct { + ж *VMap +} + +func main() { + mv := MapView{&VMap{}} + fmt.Println(mv.ж) +} + +// Output: +// &{map[]} diff --git a/_test/gen13.go b/_test/gen13.go new file mode 100644 index 000000000..11c17ddb9 --- /dev/null +++ b/_test/gen13.go @@ -0,0 +1,18 @@ +package main + +type Map[K comparable, V any] struct { + ж map[K]V +} + +func (m Map[K, V]) Has(k K) bool { + _, ok := m.ж[k] + return ok +} + +func main() { + m := Map[string, float64]{} + println(m.Has("test")) +} + +// Output: +// false diff --git a/_test/issue-1460.go b/_test/issue-1460.go index ae5040454..44e14c8ff 100644 --- a/_test/issue-1460.go +++ b/_test/issue-1460.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "net/netip" "reflect" ) @@ -17,6 +18,10 @@ func unmarshalJSON[T any](b []byte, x *[]T) error { return json.Unmarshal(b, x) } +func SliceOfViews[T ViewCloner[T, V], V StructView[T]](x []T) SliceView[T, V] { + return SliceView[T, V]{x} +} + type StructView[T any] interface { Valid() bool AsStruct() T @@ -31,10 +36,6 @@ type ViewCloner[T any, V StructView[T]] interface { Clone() T } -func SliceOfViews[T ViewCloner[T, V], V StructView[T]](x []T) SliceView[T, V] { - return SliceView[T, V]{x} -} - func (v SliceView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } func (v *SliceView[T, V]) UnmarshalJSON(b []byte) error { return unmarshalJSON(b, &v.ж) } @@ -51,6 +52,10 @@ func SliceOf[T any](x []T) Slice[T] { return Slice[T]{x} } +type IPPrefixSlice struct { + ж Slice[netip.Prefix] +} + type viewStruct struct { Int int Strings Slice[string] diff --git a/_test/issue-1488.go b/_test/issue-1488.go new file mode 100644 index 000000000..d26302a64 --- /dev/null +++ b/_test/issue-1488.go @@ -0,0 +1,23 @@ +package main + +import "fmt" + +type vector interface { + []int | [3]int +} + +func sum[V vector](v V) (out int) { + for i := 0; i < len(v); i++ { + out += v[i] + } + return +} + +func main() { + va := [3]int{1, 2, 3} + vs := []int{1, 2, 3} + fmt.Println(sum[[3]int](va), sum[[]int](vs)) +} + +// Output: +// 6 6 diff --git a/_test/p6.go b/_test/p6.go new file mode 100644 index 000000000..92436bc68 --- /dev/null +++ b/_test/p6.go @@ -0,0 +1,14 @@ +package main + +import ( + "fmt" + + "github.com/traefik/yaegi/_test/p6" +) + +func main() { + t := p6.IPPrefixSlice{} + fmt.Println(t) + b, e := t.MarshalJSON() + fmt.Println(string(b), e) +} diff --git a/_test/p6/p6.go b/_test/p6/p6.go new file mode 100644 index 000000000..52cb50b3f --- /dev/null +++ b/_test/p6/p6.go @@ -0,0 +1,21 @@ +package p6 + +import ( + "encoding/json" + "net/netip" +) + +type Slice[T any] struct { + x []T +} + +type IPPrefixSlice struct { + x Slice[netip.Prefix] +} + +func (v Slice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.x) } + +// MarshalJSON implements json.Marshaler. +func (v IPPrefixSlice) MarshalJSON() ([]byte, error) { + return v.x.MarshalJSON() +} diff --git a/interp/cfg.go b/interp/cfg.go index a805e9ba7..47e646c66 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -322,8 +322,60 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string } } if n.typ == nil { - err = n.cfgErrorf("undefined type") - return false + // A nil type indicates either an error or a generic type. + // A child indexExpr or indexListExpr is used for type parameters, + // it indicates an instanciated generic. + if n.child[0].kind != indexExpr && n.child[0].kind != indexListExpr { + err = n.cfgErrorf("undefined type") + return false + } + t0, err1 := nodeType(interp, sc, n.child[0].child[0]) + if err1 != nil { + return false + } + if t0.cat != genericT { + err = n.cfgErrorf("undefined type") + return false + } + // We have a composite literal of generic type, instantiate it. + lt := []*itype{} + for _, n1 := range n.child[0].child[1:] { + t1, err1 := nodeType(interp, sc, n1) + if err1 != nil { + return false + } + lt = append(lt, t1) + } + var g *node + g, _, err = genAST(sc, t0.node.anc, lt) + if err != nil { + return false + } + n.child[0] = g.lastChild() + n.typ, err = nodeType(interp, sc, n.child[0]) + if err != nil { + return false + } + // Generate methods if any. + for _, nod := range t0.method { + gm, _, err2 := genAST(nod.scope, nod, lt) + if err2 != nil { + err = err2 + return false + } + gm.typ, err = nodeType(interp, nod.scope, gm.child[2]) + if err != nil { + return false + } + if _, err = interp.cfg(gm, sc, sc.pkgID, sc.pkgName); err != nil { + return false + } + if err = genRun(gm); err != nil { + return false + } + n.typ.addMethod(gm) + } + n.nleft = 1 // Indictate the type of composite literal. } } @@ -439,6 +491,19 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string if typ, err = nodeType(interp, sc, recvTypeNode); err != nil { return false } + if typ.cat == nilT { + // This may happen when instantiating generic methods. + s2, _, ok := sc.lookup(typ.id()) + if !ok { + err = n.cfgErrorf("type not found: %s", typ.id()) + break + } + typ = s2.typ + if typ.cat == nilT { + err = n.cfgErrorf("nil type: %s", typ.id()) + break + } + } recvTypeNode.typ = typ n.child[2].typ.recv = typ n.typ.recv = typ @@ -871,16 +936,18 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string n.typ = t return } - g, err := genAST(sc, t.node.anc, []*node{c1}) + g, found, err := genAST(sc, t.node.anc, []*itype{c1.typ}) if err != nil { return } - if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil { - return - } - // Generate closures for function body. - if err = genRun(g.child[3]); err != nil { - return + if !found { + if _, err = interp.cfg(g, t.node.anc.scope, importPath, pkgName); err != nil { + return + } + // Generate closures for function body. + if err = genRun(g.child[3]); err != nil { + return + } } // Replace generic func node by instantiated one. n.anc.child[childPos(n)] = g @@ -1030,17 +1097,23 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string case c0.kind == indexListExpr: // Instantiate a generic function then call it. fun := c0.child[0].sym.node - g, err := genAST(sc, fun, c0.child[1:]) - if err != nil { - return + lt := []*itype{} + for _, c := range c0.child[1:] { + lt = append(lt, c.typ) } - _, err = interp.cfg(g, nil, importPath, pkgName) + g, found, err := genAST(sc, fun, lt) if err != nil { return } - err = genRun(g.child[3]) // Generate closures for function body. - if err != nil { - return + if !found { + _, err = interp.cfg(g, fun.scope, importPath, pkgName) + if err != nil { + return + } + err = genRun(g.child[3]) // Generate closures for function body. + if err != nil { + return + } } n.child[0] = g c0 = n.child[0] @@ -1212,23 +1285,26 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string if isGeneric(c0.typ) { fun := c0.typ.node.anc var g *node - var types []*node + var types []*itype + var found bool // Infer type parameter from function call arguments. if types, err = inferTypesFromCall(sc, fun, n.child[1:]); err != nil { break } // Generate an instantiated AST from the generic function one. - if g, err = genAST(sc, fun, types); err != nil { - break - } - // Compile the generated function AST, so it becomes part of the scope. - if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil { + if g, found, err = genAST(sc, fun, types); err != nil { break } - // AST compilation part 2: Generate closures for function body. - if err = genRun(g.child[3]); err != nil { - break + if !found { + // Compile the generated function AST, so it becomes part of the scope. + if _, err = interp.cfg(g, fun.scope, importPath, pkgName); err != nil { + break + } + // AST compilation part 2: Generate closures for function body. + if err = genRun(g.child[3]); err != nil { + break + } } n.child[0] = g c0 = n.child[0] @@ -1487,6 +1563,10 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string sym, level, found := sc.lookup(n.ident) if !found { + if n.typ != nil { + // Node is a generic instance with an already populated type. + break + } // retry with the filename, in case ident is a package name. sym, level, found = sc.lookup(filepath.Join(n.ident, baseName)) if !found { @@ -1916,7 +1996,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string err = n.cfgErrorf("undefined selector: %s", n.child[1].ident) } } - if err == nil && n.findex != -1 { + if err == nil && n.findex != -1 && n.typ.cat != genericT { n.findex = sc.add(n.typ) } @@ -2375,11 +2455,13 @@ func (n *node) cfgErrorf(format string, a ...interface{}) *cfgError { func genRun(nod *node) error { var err error + seen := map[*node]bool{} nod.Walk(func(n *node) bool { - if err != nil { + if err != nil || seen[n] { return false } + seen[n] = true switch n.kind { case funcType: if len(n.anc.child) == 4 { diff --git a/interp/generic.go b/interp/generic.go index ec9ff3e40..da135642f 100644 --- a/interp/generic.go +++ b/interp/generic.go @@ -5,8 +5,11 @@ import ( "sync/atomic" ) +// adot produces an AST dot(1) directed acyclic graph for the given node. For debugging only. +// func (n *node) adot() { n.astDot(dotWriter(n.interp.dotCmd), n.ident) } + // genAST returns a new AST where generic types are replaced by instantiated types. -func genAST(sc *scope, root *node, types []*node) (*node, error) { +func genAST(sc *scope, root *node, types []*itype) (*node, bool, error) { typeParam := map[string]*node{} pindex := 0 tname := "" @@ -14,9 +17,20 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { recvrPtr := false fixNodes := []*node{} var gtree func(*node, *node) (*node, error) + sname := root.child[0].ident + "[" + if root.kind == funcDecl { + sname = root.child[1].ident + "[" + } + + // Input type parameters must be resolved prior AST generation, as compilation + // of generated AST may occur in a different scope. + for _, t := range types { + sname += t.id() + "," + } + sname = strings.TrimSuffix(sname, ",") + "]" gtree = func(n, anc *node) (*node, error) { - nod := copyNode(n, anc) + nod := copyNode(n, anc, false) switch n.kind { case funcDecl, funcType: nod.val = nod @@ -27,7 +41,8 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { if !ok { break } - nod = copyNode(nt, anc) + nod = copyNode(nt, anc, true) + nod.typ = nt.typ case indexExpr: // Catch a possible recursive generic type definition @@ -37,7 +52,7 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { if root.child[0].ident != n.child[0].ident { break } - nod := copyNode(n.child[0], anc) + nod := copyNode(n.child[0], anc, false) fixNodes = append(fixNodes, nod) return nod, nil @@ -51,10 +66,16 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { if pindex >= len(types) { return nil, cc.cfgErrorf("undefined type for %s", cc.ident) } - if err := checkConstraint(sc, types[pindex], c.child[l]); err != nil { + t, err := nodeType(c.interp, sc, c.child[l]) + if err != nil { + return nil, err + } + if err := checkConstraint(types[pindex], t); err != nil { return nil, err } - typeParam[cc.ident] = types[pindex] + typeParam[cc.ident] = copyNode(cc, cc.anc, false) + typeParam[cc.ident].ident = types[pindex].id() + typeParam[cc.ident].typ = types[pindex] pindex++ } } @@ -65,9 +86,9 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { // Node is the receiver of a generic method. if root.kind == funcDecl && n.anc == root && childPos(n) == 0 && len(n.child) > 0 { rtn := n.child[0].child[1] - if rtn.kind == indexExpr || (rtn.kind == starExpr && rtn.child[0].kind == indexExpr) { - // Method receiver is a generic type. - if rtn.kind == starExpr && rtn.child[0].kind == indexExpr { + // Method receiver is a generic type if it takes some type parameters. + if rtn.kind == indexExpr || rtn.kind == indexListExpr || (rtn.kind == starExpr && (rtn.child[0].kind == indexExpr || rtn.child[0].kind == indexListExpr)) { + if rtn.kind == starExpr { // Method receiver is a pointer on a generic type. rtn = rtn.child[0] recvrPtr = true @@ -77,11 +98,10 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { if pindex >= len(types) { return nil, cc.cfgErrorf("undefined type for %s", cc.ident) } - it, err := nodeType(n.interp, sc, types[pindex]) - if err != nil { - return nil, err - } - typeParam[cc.ident] = types[pindex] + it := types[pindex] + typeParam[cc.ident] = copyNode(cc, cc.anc, false) + typeParam[cc.ident].ident = it.id() + typeParam[cc.ident].typ = it rtname += it.id() + "," pindex++ } @@ -99,14 +119,17 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { if pindex >= len(types) { return nil, cc.cfgErrorf("undefined type for %s", cc.ident) } - it, err := nodeType(n.interp, sc, types[pindex]) + it := types[pindex] + t, err := nodeType(c.interp, sc, c.child[l]) if err != nil { return nil, err } - if err := checkConstraint(sc, types[pindex], c.child[l]); err != nil { + if err := checkConstraint(types[pindex], t); err != nil { return nil, err } - typeParam[cc.ident] = types[pindex] + typeParam[cc.ident] = copyNode(cc, cc.anc, false) + typeParam[cc.ident].ident = it.id() + typeParam[cc.ident].typ = it tname += it.id() + "," pindex++ } @@ -115,6 +138,7 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { return nod, nil } } + for _, c := range n.child { gn, err := gtree(c, nod) if err != nil { @@ -125,10 +149,16 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { return nod, nil } + if nod, found := root.interp.generic[sname]; found { + return nod, true, nil + } + r, err := gtree(root, root.anc) if err != nil { - return nil, err + return nil, false, err } + root.interp.generic[sname] = r + r.param = append(r.param, types...) if tname != "" { for _, nod := range fixNodes { nod.ident = tname @@ -145,11 +175,11 @@ func genAST(sc *scope, root *node, types []*node) (*node, error) { nod.ident = rtname nod.child = nil } - // r.astDot(dotWriter(root.interp.dotCmd), root.child[1].ident) // Used for debugging only. - return r, nil + // r.adot() // Used for debugging only. + return r, false, nil } -func copyNode(n, anc *node) *node { +func copyNode(n, anc *node, recursive bool) *node { var i interface{} nindex := atomic.AddInt64(&n.interp.nindex, 1) nod := &node{ @@ -170,25 +200,30 @@ func copyNode(n, anc *node) *node { meta: n.meta, } nod.start = nod + if recursive { + for _, c := range n.child { + nod.child = append(nod.child, copyNode(c, nod, true)) + } + } return nod } -func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*node, error) { +func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*itype, error) { ftn := fun.typ.node // Fill the map of parameter types, indexed by type param ident. - types := map[string]*itype{} + paramTypes := map[string]*itype{} for _, c := range ftn.child[0].child { typ, err := nodeType(fun.interp, sc, c.lastChild()) if err != nil { return nil, err } for _, cc := range c.child[:len(c.child)-1] { - types[cc.ident] = typ + paramTypes[cc.ident] = typ } } - var inferTypes func(*itype, *itype) ([]*node, error) - inferTypes = func(param, input *itype) ([]*node, error) { + var inferTypes func(*itype, *itype) ([]*itype, error) + inferTypes = func(param, input *itype) ([]*itype, error) { switch param.cat { case chanT, ptrT, sliceT: return inferTypes(param.val, input.val) @@ -205,65 +240,68 @@ func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*node, error) { return append(k, v...), nil case structT: - nods := []*node{} + lt := []*itype{} for i, f := range param.field { nl, err := inferTypes(f.typ, input.field[i].typ) if err != nil { return nil, err } - nods = append(nods, nl...) + lt = append(lt, nl...) } - return nods, nil + return lt, nil case funcT: - nods := []*node{} + lt := []*itype{} for i, t := range param.arg { + if i >= len(input.arg) { + break + } nl, err := inferTypes(t, input.arg[i]) if err != nil { return nil, err } - nods = append(nods, nl...) + lt = append(lt, nl...) } for i, t := range param.ret { + if i >= len(input.ret) { + break + } nl, err := inferTypes(t, input.ret[i]) if err != nil { return nil, err } - nods = append(nods, nl...) + lt = append(lt, nl...) + } + return lt, nil + + case nilT: + if paramTypes[param.name] != nil { + return []*itype{input}, nil } - return nods, nil case genericT: - return []*node{input.node}, nil + return []*itype{input}, nil } return nil, nil } - nodes := []*node{} + types := []*itype{} for i, c := range ftn.child[1].child { typ, err := nodeType(fun.interp, sc, c.lastChild()) if err != nil { return nil, err } - nods, err := inferTypes(typ, args[i].typ) + lt, err := inferTypes(typ, args[i].typ) if err != nil { return nil, err } - nodes = append(nodes, nods...) + types = append(types, lt...) } - return nodes, nil + return types, nil } -func checkConstraint(sc *scope, input, constraint *node) error { - ct, err := nodeType(constraint.interp, sc, constraint) - if err != nil { - return err - } - it, err := nodeType(input.interp, sc, input) - if err != nil { - return err - } +func checkConstraint(it, ct *itype) error { if len(ct.constraint) == 0 && len(ct.ulconstraint) == 0 { return nil } @@ -277,5 +315,5 @@ func checkConstraint(sc *scope, input, constraint *node) error { return nil } } - return input.cfgErrorf("%s does not implement %s", input.typ.id(), ct.id()) + return it.node.cfgErrorf("%s does not implement %s", it.id(), ct.id()) } diff --git a/interp/gta.go b/interp/gta.go index 39e5c79dc..28f84aee2 100644 --- a/interp/gta.go +++ b/interp/gta.go @@ -21,6 +21,9 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ if err != nil { return false } + if n.scope == nil { + n.scope = sc + } switch n.kind { case constDecl: // Early parse of constDecl subtree, to compute all constant @@ -166,7 +169,7 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ typName = c.child[0].ident genericMethod = true } - case indexExpr: + case indexExpr, indexListExpr: genericMethod = true } } @@ -189,6 +192,14 @@ func (interp *Interpreter) gta(root *node, rpath, importPath, pkgName string) ([ } rcvrtype.addMethod(n) rtn.typ = rcvrtype + if rcvrtype.cat == genericT { + // generate methods for already instantiated receivers + for _, it := range rcvrtype.instance { + if err = genMethod(interp, sc, it, n, it.node.anc.param); err != nil { + return false + } + } + } case ident == "init": // init functions do not get declared as per the Go spec. default: diff --git a/interp/interp.go b/interp/interp.go index 3803d6fe9..11650a80b 100644 --- a/interp/interp.go +++ b/interp/interp.go @@ -28,6 +28,7 @@ type node struct { debug *nodeDebugData // debug info child []*node // child subtrees (AST) anc *node // ancestor (AST) + param []*itype // generic parameter nodes (AST) start *node // entry point in subtree (CFG) tnext *node // true branch successor (CFG) fnext *node // false branch successor (CFG) @@ -215,6 +216,7 @@ type Interpreter struct { pkgNames map[string]string // package names, indexed by import path done chan struct{} // for cancellation of channel operations roots []*node + generic map[string]*node hooks *hooks // symbol hooks @@ -335,6 +337,7 @@ func New(options Options) *Interpreter { pkgNames: map[string]string{}, rdir: map[string]bool{}, hooks: &hooks{}, + generic: map[string]*node{}, } if i.opt.stdin = options.Stdin; i.opt.stdin == nil { diff --git a/interp/type.go b/interp/type.go index fefa57117..7b036f9cb 100644 --- a/interp/type.go +++ b/interp/type.go @@ -126,6 +126,7 @@ type itype struct { method []*node // Associated methods or nil constraint []*itype // For interfaceT: list of types part of interface set ulconstraint []*itype // For interfaceT: list of underlying types part of interface set + instance []*itype // For genericT: list of instantiated types name string // name of type within its package for a defined type path string // for a defined type, the package import path length int // length of array if ArrayT @@ -786,6 +787,11 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, } else { t = sym.typ } + if t == nil { + if t, err = nodeType2(interp, sc, sym.node, seen); err != nil { + return nil, err + } + } if t.incomplete && t.cat == linkedT && t.val != nil && t.val.cat != nilT { t.incomplete = false } @@ -807,7 +813,11 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, return nil, err } if lt.incomplete { - t.incomplete = true + if t == nil { + t = lt + } else { + t.incomplete = true + } break } switch lt.cat { @@ -828,7 +838,7 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, break } // A generic type is being instantiated. Generate it. - t, err = genType(interp, sc, name, lt, []*node{t1.node}, seen) + t, err = genType(interp, sc, name, lt, []*itype{t1}, seen) if err != nil { return nil, err } @@ -840,6 +850,15 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, if lt, err = nodeType2(interp, sc, n.child[0], seen); err != nil { return nil, err } + if lt.incomplete { + if t == nil { + t = lt + } else { + t.incomplete = true + } + break + } + // Index list expressions can be used only in context of generic types. if lt.cat != genericT { err = n.cfgErrorf("not a generic type: %s", lt.id()) @@ -847,7 +866,7 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, } name := lt.id() + "[" out := false - tnodes := []*node{} + types := []*itype{} for _, c := range n.child[1:] { t1, err := nodeType2(interp, sc, c, seen) if err != nil { @@ -858,19 +877,19 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, out = true break } - tnodes = append(tnodes, t1.node) + types = append(types, t1) name += t1.id() + "," } if out { break } - name += "]" + name = strings.TrimSuffix(name, ",") + "]" if sym, _, found := sc.lookup(name); found { t = sym.typ break } // A generic type is being instantiated. Generate it. - t, err = genType(interp, sc, name, lt, tnodes, seen) + t, err = genType(interp, sc, name, lt, types, seen) case interfaceType: if sname := typeName(n); sname != "" { @@ -1016,7 +1035,7 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, sname := structName(n) if sname != "" { sym, _, found = sc.lookup(sname) - if found && sym.kind == typeSym { + if found && sym.kind == typeSym && sym.typ != nil { t = structOf(sym.typ, sym.typ.field, withNode(n), withScope(sc)) } else { t = structOf(nil, nil, withNode(n), withScope(sc)) @@ -1062,6 +1081,9 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, t = structOf(t, fields, withNode(n), withScope(sc)) t.incomplete = incomplete if sname != "" { + if sc.sym[sname] == nil { + sc.sym[sname] = &symbol{index: -1, kind: typeSym, node: n} + } sc.sym[sname].typ = t } @@ -1094,9 +1116,9 @@ func nodeType2(interp *Interpreter, sc *scope, n *node, seen []*node) (t *itype, return t, err } -func genType(interp *Interpreter, sc *scope, name string, lt *itype, tnodes, seen []*node) (t *itype, err error) { +func genType(interp *Interpreter, sc *scope, name string, lt *itype, types []*itype, seen []*node) (t *itype, err error) { // A generic type is being instantiated. Generate it. - g, err := genAST(sc, lt.node.anc, tnodes) + g, _, err := genAST(sc, lt.node.anc, types) if err != nil { return nil, err } @@ -1104,39 +1126,48 @@ func genType(interp *Interpreter, sc *scope, name string, lt *itype, tnodes, see if err != nil { return nil, err } + lt.instance = append(lt.instance, t) + // Add generated symbol in the scope of generic source and user. sc.sym[name] = &symbol{index: -1, kind: typeSym, typ: t, node: g} - - // Instantiate type methods (if any). - var pt *itype - if len(lt.method) > 0 { - pt = ptrOf(t, withNode(g), withScope(sc)) + if lt.scope.sym[name] == nil { + lt.scope.sym[name] = sc.sym[name] } + for _, nod := range lt.method { - gm, err := genAST(sc, nod, tnodes) - if err != nil { - return nil, err - } - if gm.typ, err = nodeType(interp, sc, gm.child[2]); err != nil { - return nil, err - } - t.addMethod(gm) - if rtn := gm.child[0].child[0].lastChild(); rtn.kind == starExpr { - // The receiver is a pointer on a generic type. - pt.addMethod(gm) - rtn.typ = pt - } - // Compile method CFG. - if _, err = interp.cfg(gm, sc, sc.pkgID, sc.pkgName); err != nil { - return nil, err - } - // Generate closures for function body. - if err = genRun(gm); err != nil { + if err := genMethod(interp, sc, t, nod, types); err != nil { return nil, err } } return t, err } +func genMethod(interp *Interpreter, sc *scope, t *itype, nod *node, types []*itype) error { + gm, _, err := genAST(sc, nod, types) + if err != nil { + return err + } + if gm.typ, err = nodeType(interp, sc, gm.child[2]); err != nil { + return err + } + t.addMethod(gm) + + // If the receiver is a pointer to a generic type, generate also the pointer type. + if rtn := gm.child[0].child[0].lastChild(); rtn != nil && rtn.kind == starExpr { + pt := ptrOf(t, withNode(t.node), withScope(sc)) + pt.addMethod(gm) + rtn.typ = pt + } + + // Compile the method AST in the scope of the generic type. + scop := nod.typ.scope + if _, err = interp.cfg(gm, scop, scop.pkgID, scop.pkgName); err != nil { + return err + } + + // Generate closures for function body. + return genRun(gm) +} + // findPackageType searches the top level scope for a package type. func findPackageType(interp *Interpreter, sc *scope, n *node) *itype { // Find the root scope, the package symbols will exist there.