From e003140c6e76617b78ee3a5e783f05d3373d091c Mon Sep 17 00:00:00 2001 From: Marc Vertes Date: Wed, 19 Oct 2022 17:54:08 +0200 Subject: [PATCH] interp: improve internal handling of functions Up to now functions could be stored as node values in frame (as for interpreter defined functions) or function values, directly callable by the Go runtime. We now always store functions in the later form, making the processing of functions, anonymous closures and methods simpler and more robust. All functions, once compiled are always directly callable, with no further wrapping necessary. Fixes #1459. --- _test/cli8.go | 41 ++++++++++ _test/convert3.go | 18 +++++ _test/issue-1459.go | 22 ++++++ _test/struct49.go | 2 +- generate.go | 1 + internal/unsafe2/unsafe.go | 1 + interp/cfg.go | 7 +- interp/run.go | 148 +++++++++++++++++++------------------ interp/type.go | 28 +++---- interp/value.go | 117 ++++++----------------------- 10 files changed, 196 insertions(+), 189 deletions(-) create mode 100644 _test/cli8.go create mode 100644 _test/convert3.go create mode 100644 _test/issue-1459.go diff --git a/_test/cli8.go b/_test/cli8.go new file mode 100644 index 000000000..bdd5b4a00 --- /dev/null +++ b/_test/cli8.go @@ -0,0 +1,41 @@ +package main + +import ( + "net/http" + "net/http/httptest" +) + +type T struct { + name string + next http.Handler +} + +func (t *T) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + println("in T.ServeHTTP") + if t.next != nil { + t.next.ServeHTTP(rw, req) + } +} + +func New(name string, next http.Handler) (http.Handler, error) { return &T{name, next}, nil } + +func main() { + next := func(rw http.ResponseWriter, req *http.Request) { + println("in next") + } + + t, err := New("test", http.HandlerFunc(next)) + if err != nil { + panic(err) + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + t.ServeHTTP(recorder, req) + println(recorder.Result().Status) +} + +// Output: +// in T.ServeHTTP +// in next +// 200 OK diff --git a/_test/convert3.go b/_test/convert3.go new file mode 100644 index 000000000..58166e2eb --- /dev/null +++ b/_test/convert3.go @@ -0,0 +1,18 @@ +package main + +import ( + "fmt" + "net/http" +) + +func main() { + next := func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("Cache-Control", "max-age=20") + rw.WriteHeader(http.StatusOK) + } + f := http.HandlerFunc(next) + fmt.Printf("%T\n", f.ServeHTTP) +} + +// Output: +// func(http.ResponseWriter, *http.Request) diff --git a/_test/issue-1459.go b/_test/issue-1459.go new file mode 100644 index 000000000..148eb75de --- /dev/null +++ b/_test/issue-1459.go @@ -0,0 +1,22 @@ +package main + +import "fmt" + +type funclistItem func() + +type funclist struct { + list []funclistItem +} + +func main() { + funcs := funclist{} + + funcs.list = append(funcs.list, func() { fmt.Println("first") }) + + for _, f := range funcs.list { + f() + } +} + +// Output: +// first diff --git a/_test/struct49.go b/_test/struct49.go index 8e574770c..c8705d0e9 100644 --- a/_test/struct49.go +++ b/_test/struct49.go @@ -25,7 +25,7 @@ func main() { } s.ts["test"] = append(s.ts["test"], &T{s: s}) - t , ok:= s.getT("test") + t, ok := s.getT("test") println(t != nil, ok) } diff --git a/generate.go b/generate.go index d80bb7250..ab7fa228c 100644 --- a/generate.go +++ b/generate.go @@ -1,3 +1,4 @@ +// Package yaegi provides a Go interpreter. package yaegi //go:generate go generate github.com/traefik/yaegi/internal/cmd/extract diff --git a/internal/unsafe2/unsafe.go b/internal/unsafe2/unsafe.go index 47f96ad18..4a4b24d95 100644 --- a/internal/unsafe2/unsafe.go +++ b/internal/unsafe2/unsafe.go @@ -1,3 +1,4 @@ +// Package unsafe2 provides helpers to generate recursive struct types. package unsafe2 import ( diff --git a/interp/cfg.go b/interp/cfg.go index c77f2c57f..778bbb884 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -1707,12 +1707,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string } if c.typ.cat == nilT { // nil: Set node value to zero of return type - if typ.cat == funcT { - // Wrap the typed nil value in a node, as per other interpreter functions - c.rval = reflect.ValueOf(&node{kind: basicLit, rval: reflect.New(typ.TypeOf()).Elem()}) - } else { - c.rval = reflect.New(typ.TypeOf()).Elem() - } + c.rval = reflect.New(typ.TypeOf()).Elem() } } diff --git a/interp/run.go b/interp/run.go index c592b606d..a514374c9 100644 --- a/interp/run.go +++ b/interp/run.go @@ -566,29 +566,11 @@ func convert(n *node) { return } - if isFuncSrc(n.child[0].typ) && isFuncSrc(c.typ) { - value := genValue(c) - n.exec = func(f *frame) bltn { - n, ok := value(f).Interface().(*node) - if !ok || !n.typ.convertibleTo(c.typ) { - panic(n.cfgErrorf("cannot convert to %s", c.typ.id())) - } - n1 := *n - n1.typ = c.typ - dest(f).Set(reflect.ValueOf(&n1)) - return next - } - return - } - doConvert := true var value func(*frame) reflect.Value switch { case isFuncSrc(c.typ): value = genFunctionWrapper(c) - case isFuncSrc(n.child[0].typ) && c.typ.cat == valueT: - doConvert = false - value = genValueNode(c) default: value = genValue(c) } @@ -659,8 +641,8 @@ func assign(n *node) { for i := 0; i < n.nleft; i++ { dest, src := n.child[i], n.child[sbase+i] - if isFuncSrc(src.typ) && isField(dest) { - svalue[i] = genFunctionWrapper(src) + if isNamedFuncSrc(src.typ) { + svalue[i] = genFuncValue(src) } else { svalue[i] = genDestValue(dest.typ, src) } @@ -714,8 +696,6 @@ func assign(n *node) { for i := range types { var t reflect.Type switch typ := n.child[sbase+i].typ; { - case isFuncSrc(typ): - t = reflect.TypeOf((*node)(nil)) case isInterfaceSrc(typ): t = valueInterfaceType default: @@ -952,9 +932,6 @@ func genFunctionWrapper(n *node) func(*frame) reflect.Value { var def *node var ok bool - if n.kind == basicLit { - return func(f *frame) reflect.Value { return n.rval } - } if def, ok = n.val.(*node); !ok { return genValueAsFunctionWrapper(n) } @@ -963,11 +940,7 @@ func genFunctionWrapper(n *node) func(*frame) reflect.Value { var rcvr func(*frame) reflect.Value if n.recv != nil { - if n.recv.node.typ.cat != defRecvType(def).cat { - rcvr = genValueRecvIndirect(n) - } else { - rcvr = genValueRecv(n) - } + rcvr = genValueRecv(n) } funcType := n.typ.TypeOf() @@ -989,15 +962,20 @@ func genFunctionWrapper(n *node) func(*frame) reflect.Value { // Copy method receiver as first argument. src, dest := rcvr(f), d[numRet] sk, dk := src.Kind(), dest.Kind() + for { + vs, ok := src.Interface().(valueInterface) + if !ok { + break + } + src = vs.value + sk = src.Kind() + } switch { case sk == reflect.Ptr && dk != reflect.Ptr: dest.Set(src.Elem()) case sk != reflect.Ptr && dk == reflect.Ptr: dest.Set(src.Addr()) default: - if wrappedSrc, ok := src.Interface().(valueInterface); ok { - src = wrappedSrc.value - } dest.Set(src) } d = d[numRet+1:] @@ -1015,8 +993,6 @@ func genFunctionWrapper(n *node) func(*frame) reflect.Value { d[i].Set(arg) case isInterfaceSrc(typ): d[i].Set(reflect.ValueOf(valueInterface{value: arg.Elem()})) - case isFuncSrc(typ) && arg.Kind() == reflect.Func: - d[i].Set(reflect.ValueOf(genFunctionNode(arg))) default: d[i].Set(arg) } @@ -1025,21 +1001,11 @@ func genFunctionWrapper(n *node) func(*frame) reflect.Value { // Interpreter code execution. runCfg(start, fr, def, n) - result := fr.data[:numRet] - for i, r := range result { - if v, ok := r.Interface().(*node); ok { - result[i] = genFunctionWrapper(v)(f) - } - } - return result + return fr.data[:numRet] }) } } -func genFunctionNode(v reflect.Value) *node { - return &node{kind: funcType, action: aNop, rval: v, typ: valueTOf(v.Type())} -} - func genInterfaceWrapper(n *node, typ reflect.Type) func(*frame) reflect.Value { value := genValue(n) if typ == nil || typ.Kind() != reflect.Interface || typ.NumMethod() == 0 || n.typ.cat == valueT { @@ -1178,11 +1144,6 @@ func call(n *node) { // Compute method receiver value. values = append(values, genValueRecv(c0)) method = true - case len(c0.child) > 0 && c0.child[0].typ != nil && isInterfaceSrc(c0.child[0].typ): - recvIndexLater = true - values = append(values, genValueBinRecv(c0, &receiver{node: c0.child[0]})) - value = genValueBinMethodOnInterface(n, value) - method = true case c0.action == aMethod: // Add a place holder for interface method receiver. values = append(values, nil) @@ -1244,7 +1205,7 @@ func call(n *node) { case isInterfaceBin(arg): values = append(values, genInterfaceWrapper(c, arg.rtype)) case isFuncSrc(arg): - values = append(values, genValueNode(c)) + values = append(values, genFuncValue(c)) default: values = append(values, genValue(c)) } @@ -1309,21 +1270,37 @@ func call(n *node) { var ok bool bf := value(f) + if def, ok = bf.Interface().(*node); ok { bf = def.rval } // Call bin func if defined if bf.IsValid() { + var callf func([]reflect.Value) []reflect.Value + + // Lambda definitions are necessary here. Due to reflect internals, + // having `callf = bf.Call` or `callf = bf.CallSlice` does not work. + //nolint:gocritic + if hasVariadicArgs { + callf = func(in []reflect.Value) []reflect.Value { return bf.CallSlice(in) } + } else { + callf = func(in []reflect.Value) []reflect.Value { return bf.Call(in) } + } + + if method && len(values) > bf.Type().NumIn() { + // The receiver is already passed in the function wrapper, skip it. + values = values[1:] + } in := make([]reflect.Value, len(values)) for i, v := range values { in[i] = v(f) } if goroutine { - go bf.Call(in) + go callf(in) return tnext } - out := bf.Call(in) + out := callf(in) for i, v := range rvalues { if v != nil { v(f).Set(out[i]) @@ -1557,8 +1534,6 @@ func callBin(n *node) { } switch { - case isFuncSrc(c.typ): - values = append(values, genFunctionWrapper(c)) case isEmptyInterface(c.typ): values = append(values, genValue(c)) case isInterfaceSrc(c.typ): @@ -1910,8 +1885,10 @@ func getIndexMap2(n *node) { const fork = true // Duplicate frame in frame.clone(). +// getFunc compiles a closure function generator for anonymous functions. func getFunc(n *node) { - dest := genValue(n) + i := n.findex + l := n.level next := getExec(n.tnext) n.exec = func(f *frame) bltn { @@ -1919,7 +1896,42 @@ func getFunc(n *node) { nod := *n nod.val = &nod nod.frame = fr - dest(f).Set(reflect.ValueOf(&nod)) + def := &nod + numRet := len(def.typ.ret) + + fct := reflect.MakeFunc(nod.typ.TypeOf(), func(in []reflect.Value) []reflect.Value { + // Allocate and init local frame. All values to be settable and addressable. + fr2 := newFrame(fr, len(def.types), fr.runid()) + d := fr2.data + for i, t := range def.types { + d[i] = reflect.New(t).Elem() + } + d = d[numRet:] + + // Copy function input arguments in local frame. + for i, arg := range in { + if i >= len(d) { + // In case of unused arg, there may be not even a frame entry allocated, just skip. + break + } + typ := def.typ.arg[i] + switch { + case isEmptyInterface(typ) || typ.TypeOf() == valueInterfaceType: + d[i].Set(arg) + case isInterfaceSrc(typ): + d[i].Set(reflect.ValueOf(valueInterface{value: arg.Elem()})) + default: + d[i].Set(arg) + } + } + + // Interpreter code execution. + runCfg(def.child[3].start, fr2, def, n) + + return fr2.data[:numRet] + }) + + getFrame(f, l).data[i] = fct return next } } @@ -1935,7 +1947,7 @@ func getMethod(n *node) { nod.val = &nod nod.recv = n.recv nod.frame = fr - getFrame(f, l).data[i] = reflect.ValueOf(&nod) + getFrame(f, l).data[i] = genFuncValue(&nod)(f) return next } } @@ -2002,7 +2014,7 @@ func getMethodByName(n *node) { nod.val = &nod nod.recv = &receiver{nil, val.value, li} nod.frame = fr - getFrame(f, l).data[i] = reflect.ValueOf(&nod) + getFrame(f, l).data[i] = genFuncValue(&nod)(f) return next } } @@ -2390,8 +2402,6 @@ func _return(n *node) { } else { values[i] = genValue(c) } - case funcT: - values[i] = genValue(c) case interfaceT: if len(t.field) == 0 { // empty interface case. @@ -2702,7 +2712,7 @@ func doComposite(n *node, hasType bool, keyed bool) { switch { case val.typ.cat == nilT: values[fieldIndex] = func(*frame) reflect.Value { return reflect.New(rft).Elem() } - case isFuncSrc(val.typ): + case isNamedFuncSrc(val.typ): values[fieldIndex] = genValueAsFunctionWrapper(val) case isInterfaceSrc(ft) && (!isEmptyInterface(ft) || len(val.typ.method) > 0): values[fieldIndex] = genValueInterface(val) @@ -3845,11 +3855,7 @@ func slice0(n *node) { func isNil(n *node) { var value func(*frame) reflect.Value c0 := n.child[0] - if isFuncSrc(c0.typ) { - value = genValueAsFunctionWrapper(c0) - } else { - value = genValue(c0) - } + value = genValue(c0) typ := n.typ.concrete().TypeOf() isInterface := n.typ.TypeOf().Kind() == reflect.Interface tnext := getExec(n.tnext) @@ -3934,11 +3940,7 @@ func isNil(n *node) { func isNotNil(n *node) { var value func(*frame) reflect.Value c0 := n.child[0] - if isFuncSrc(c0.typ) { - value = genValueAsFunctionWrapper(c0) - } else { - value = genValue(c0) - } + value = genValue(c0) typ := n.typ.concrete().TypeOf() isInterface := n.typ.TypeOf().Kind() == reflect.Interface tnext := getExec(n.tnext) diff --git a/interp/type.go b/interp/type.go index cf0dee170..ded5c875e 100644 --- a/interp/type.go +++ b/interp/type.go @@ -1901,6 +1901,7 @@ type refTypeContext struct { // "top-level" point. rect *itype rebuilding bool + slevel int } // Clone creates a copy of the ref type context. @@ -2040,11 +2041,19 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type { case mapT: t.rtype = reflect.MapOf(t.key.refType(ctx), t.val.refType(ctx)) case ptrT: - t.rtype = reflect.PtrTo(t.val.refType(ctx)) + rt := t.val.refType(ctx) + if rt == unsafe2.DummyType && ctx.slevel > 1 { + // We have a pointer to a recursive struct which is not yet fully computed. + // Return it but do not yet store it in rtype, so the complete version can + // be stored in future. + return reflect.PtrTo(rt) + } + t.rtype = reflect.PtrTo(rt) case structT: if t.name != "" { ctx.defined[name] = t } + ctx.slevel++ var fields []reflect.StructField for i, f := range t.field { field := reflect.StructField{ @@ -2061,6 +2070,7 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type { } } } + ctx.slevel-- fieldFix := []int{} // Slice of field indices to fix for recursivity. t.rtype = reflect.StructOf(fields) if ctx.isComplete() { @@ -2116,8 +2126,6 @@ func (t *itype) frameType() (r reflect.Type) { r = reflect.ArrayOf(t.length, t.val.frameType()) case sliceT, variadicT: r = reflect.SliceOf(t.val.frameType()) - case funcT: - r = reflect.TypeOf((*node)(nil)) case interfaceT: if len(t.field) == 0 { // empty interface, do not wrap it @@ -2241,16 +2249,6 @@ func constToString(v reflect.Value) string { return constant.StringVal(c) } -func defRecvType(n *node) *itype { - if n.kind != funcDecl || len(n.child[0].child) == 0 { - return nil - } - if r := n.child[0].child[0].lastChild(); r != nil { - return r.typ - } - return nil -} - func wrappedType(n *node) *itype { if n.typ.cat != valueT { return nil @@ -2293,6 +2291,10 @@ func isGeneric(t *itype) bool { return t.cat == funcT && t.node != nil && len(t.node.child) > 0 && len(t.node.child[0].child) > 0 } +func isNamedFuncSrc(t *itype) bool { + return isFuncSrc(t) && t.node.anc.kind == funcDecl +} + func isFuncSrc(t *itype) bool { return t.cat == funcT || (t.cat == aliasT && isFuncSrc(t.val)) } diff --git a/interp/value.go b/interp/value.go index 7ecf484a4..93ff0f5f2 100644 --- a/interp/value.go +++ b/interp/value.go @@ -40,90 +40,15 @@ func valueOf(data []reflect.Value, i int) reflect.Value { return data[i] } -func genValueBinMethodOnInterface(n *node, defaultGen func(*frame) reflect.Value) func(*frame) reflect.Value { - if n == nil || n.child == nil || n.child[0] == nil || - n.child[0].child == nil || n.child[0].child[0] == nil { - return defaultGen - } - c0 := n.child[0] - if c0.child[1] == nil || c0.child[1].ident == "" { - return defaultGen - } - value0 := genValue(c0.child[0]) - - return func(f *frame) reflect.Value { - v := value0(f) - var nod *node - - for v.IsValid() { - // Traverse interface indirections to find out concrete type. - vi, ok := v.Interface().(valueInterface) - if !ok { - break - } - v = vi.value - nod = vi.node - } - - if nod == nil || nod.typ.rtype == nil { - return defaultGen(f) - } - - // Try to get the bin method, if it doesnt exist, fall back to - // the default generator function. - meth, ok := nod.typ.rtype.MethodByName(c0.child[1].ident) - if !ok { - return defaultGen(f) - } - - return meth.Func - } -} - -func genValueRecvIndirect(n *node) func(*frame) reflect.Value { - vr := genValueRecv(n) - return func(f *frame) reflect.Value { - v := vr(f) - if vi, ok := v.Interface().(valueInterface); ok { - return vi.value - } - return v.Elem() - } -} - func genValueRecv(n *node) func(*frame) reflect.Value { - v := genValue(n.recv.node) - fi := n.recv.index - - if len(fi) == 0 { - return v - } - - return func(f *frame) reflect.Value { - r := v(f) - if r.Kind() == reflect.Ptr { - r = r.Elem() - } - return r.FieldByIndex(fi) - } -} - -func genValueBinRecv(n *node, recv *receiver) func(*frame) reflect.Value { - value := genValue(n) - binValue := genValue(recv.node) - - v := func(f *frame) reflect.Value { - if def, ok := value(f).Interface().(*node); ok { - if def != nil && def.recv != nil && def.recv.val.IsValid() { - return def.recv.val - } - } - - ival, _ := binValue(f).Interface().(valueInterface) - return ival.value + var v func(*frame) reflect.Value + if n.recv.node == nil { + v = func(*frame) reflect.Value { return n.recv.val } + } else { + v = genValue(n.recv.node) } + fi := n.recv.index - fi := recv.index if len(fi) == 0 { return v } @@ -146,6 +71,9 @@ func genValueAsFunctionWrapper(n *node) func(*frame) reflect.Value { if v.IsNil() { return reflect.New(typ).Elem() } + if v.Kind() == reflect.Func { + return v + } vn, ok := v.Interface().(*node) if ok && vn.rval.Kind() == reflect.Func { // The node value is already a callable func, no need to wrap it. @@ -221,9 +149,7 @@ func genDestValue(typ *itype, n *node) func(*frame) reflect.Value { switch { case isInterfaceSrc(typ) && (!isEmptyInterface(typ) || len(n.typ.method) > 0): return genValueInterface(n) - case isFuncSrc(typ) && (n.typ.cat == valueT || n.typ.cat == nilT): - return genValueNode(n) - case typ.cat == valueT && isFuncSrc(n.typ): + case isNamedFuncSrc(n.typ): return genFunctionWrapper(n) case isInterfaceBin(typ): return genInterfaceWrapper(n, typ.rtype) @@ -237,6 +163,17 @@ func genDestValue(typ *itype, n *node) func(*frame) reflect.Value { return genValue(n) } +func genFuncValue(n *node) func(*frame) reflect.Value { + value := genValue(n) + return func(f *frame) reflect.Value { + v := value(f) + if nod, ok := v.Interface().(*node); ok { + return genFunctionWrapper(nod)(f) + } + return v + } +} + func genValueArray(n *node) func(*frame) reflect.Value { value := genValue(n) // dereference array pointer, to support array operations on array pointer @@ -419,18 +356,6 @@ func genValueInterfaceValue(n *node) func(*frame) reflect.Value { } } -func genValueNode(n *node) func(*frame) reflect.Value { - value := genValue(n) - - return func(f *frame) reflect.Value { - v := value(f) - if _, ok := v.Interface().(*node); ok { - return v - } - return reflect.ValueOf(&node{rval: v}) - } -} - func vInt(v reflect.Value) (i int64) { if c := vConstantValue(v); c != nil { i, _ = constant.Int64Val(constant.ToInt(c))