diff --git a/docs/Exprgen.md b/docs/Exprgen.md new file mode 100644 index 000000000..9e54021b8 --- /dev/null +++ b/docs/Exprgen.md @@ -0,0 +1,15 @@ +# Exprgen + +## Install +``` +go install github.com/antonmedv/expr/exprgen +``` +## Usage +Fetch methods generates for all struct/map/array/string named types(exception is map types with unnamed not basic key type like `map[struct{...}]int`). + +To generate just call exprgen with pkg paths as arguments: +``` +exprgen pkg1 pkg2 ... +``` + +After call, file `*pkg_name*_exprgen.go` will be created in each packages from arguments. diff --git a/docs/Optimizations.md b/docs/Optimizations.md index 1bbbef12d..65cc6419b 100644 --- a/docs/Optimizations.md +++ b/docs/Optimizations.md @@ -116,4 +116,18 @@ func main() { } ``` +## Reduced use of reflect + +To fetch fields from struct, values from map, get by indexes expr uses reflect package. +Envs can implement vm.Fetcher interface, to avoid use reflect: +```go +type Fetcher interface { + Fetch(interface{}) interface{} +} +``` +When you need to fetch a field, the method will be used instead reflect functions. +If the field is not found, Fetch must return nil. +To generate Fetch for your types, use [Exprgen](Exprgen.md). + + * [Contents](README.md) diff --git a/exprgen/exprgen.go b/exprgen/exprgen.go new file mode 100644 index 000000000..e1e70f5d7 --- /dev/null +++ b/exprgen/exprgen.go @@ -0,0 +1,351 @@ +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/format" + "go/importer" + "go/parser" + "go/token" + "go/types" + "io/fs" + "io/ioutil" + "os" + "path/filepath" + "sort" + "strings" +) + +const exprgenSuffix = "_exprgen.go" + +func main() { + flag.Parse() + + filenames := flag.Args() + if len(filenames) == 0 { + flag.Usage() + os.Exit(1) + } + + for _, filename := range filenames { + if err := generate(filename); err != nil { + fmt.Fprintf(os.Stderr, "generate '%s' error: %s", filename, err.Error()) + os.Exit(2) + } + } +} + +func generate(filename string) error { + fi, err := os.Stat(filename) + if err != nil { + return fmt.Errorf("stat err: %w", err) + } + + if !fi.IsDir() { + return fmt.Errorf("filename must be dir") + } + tfs := token.NewFileSet() + packages, err := parser.ParseDir(tfs, filename, func(info fs.FileInfo) bool { + return !strings.HasSuffix(info.Name(), exprgenSuffix) + }, parser.ParseComments) + if err != nil { + return fmt.Errorf("parse dir error: %w", err) + } + + typesChecker := types.Config{ + Importer: importer.ForCompiler(tfs, "source", nil), + } + + for name, pkg := range packages { + if strings.HasSuffix(name, "_test") { + continue + } + + files := make([]*ast.File, 0, len(pkg.Files)) + for _, f := range pkg.Files { + files = append(files, f) + } + + packageTypes, err := typesChecker.Check(name, tfs, files, nil) + if err != nil { + return fmt.Errorf("types check error: %w", err) + } + + b, err := fileData(name, packageTypes) + if err != nil { + return err + } + + err = ioutil.WriteFile(filepath.Join(filename, name+exprgenSuffix), b, 0644) + if err != nil { + return err + } + } + + return nil +} + +func fileData(pkgName string, pkg *types.Package) ([]byte, error) { + var data string + echo := func(s string, xs ...interface{}) { + data += fmt.Sprintf(s, xs...) + "\n" + } + echoRaw := func(s string) { + data += fmt.Sprint(s) + "\n" + } + + echo(`// Code generated by exprgen. DO NOT EDIT.`) + echo(``) + echo(`package ` + pkgName) + echo(``) + echo(`--imports`) + echo(``) + + echoRaw(`func toInt(a interface{}) int { + switch x := a.(type) { + case float32: + return int(x) + case float64: + return int(x) + + case int: + return x + case int8: + return int(x) + case int16: + return int(x) + case int32: + return int(x) + case int64: + return int(x) + + case uint: + return int(x) + case uint8: + return int(x) + case uint16: + return int(x) + case uint32: + return int(x) + case uint64: + return int(x) + + default: + panic(fmt.Sprintf("invalid operation: int(%T)", x)) + } + }`) + echo(``) + + imports := make(map[string]string) + + scope := pkg.Scope() + for _, objectName := range scope.Names() { + obj := scope.Lookup(objectName) + namedType, ok := obj.Type().(*types.Named) + if !ok { + continue + } + + recvName := "v" + for i := 0; i < namedType.NumMethods(); i++ { + method := namedType.Method(i) + signature := method.Type().(*types.Signature) + recv := signature.Recv() + if recv != nil && recv.Name() != "" { + recvName = recv.Name() + break + } + } + + switch t := namedType.Underlying().(type) { + case *types.Basic: + if t.Kind() != types.String { + break + } + + echo("func (%s %s) Fetch(i interface{}) interface{} {", recvName, objectName) + echo("return %s[toInt(i)]", recvName) + echo("}") + case *types.Slice, *types.Array: + echo("func (%s %s) Fetch(i interface{}) interface{} {", recvName, objectName) + echo("return %s[toInt(i)]", recvName) + echo("}") + case *types.Map: + echo("func (%s %s) Fetch(i interface{}) interface{} {", recvName, objectName) + key := t.Key() + + numericCases := []string{ + "int", + "int8", + "int16", + "int32", + "int64", + "uint", + "uint8", + "uint16", + "uint32", + "uint64", + "uintptr", + "float32", + "float64", + } + + switch k := key.(type) { + case *types.Named: + objKey := k.Obj() + keyName := objKey.Name() + if objKey.Pkg().Path() != pkg.Path() { + path := objKey.Pkg().Path() + name := objKey.Pkg().Name() + for imports[name] != "" && path != imports[name] { + name = name + "1" + } + imports[name] = path + keyName = name + "." + keyName + } + + echo(`switch _x_i := i.(type) {`) + echo("case %s:", keyName) + echo("return %s[_x_i]", recvName) + if basicKey, ok := k.Underlying().(*types.Basic); ok { + if basicKey.Info()&types.IsNumeric != 0 { + for _, c := range numericCases { + echo("case %s:", c) + echo("return %s[%s(_x_i)]", recvName, keyName) + } + } + if basicKey.Info()&types.IsString != 0 { + echo(`case string:`) + echo("return %s[%s(_x_i)]", recvName, keyName) + echo("default:") + imports["fmt"] = "fmt" + echo("return %s[%s(fmt.Sprint(i))]", recvName, keyName) + } + } + echo(`}`) + case *types.Basic: + keyName := k.String() + echo(`switch _x_i := i.(type) {`) + echo("case %s:", keyName) + echo("return %s[_x_i]", recvName) + if k.Info()&types.IsNumeric != 0 { + for _, c := range numericCases { + if c == keyName { + continue + } + echo("case %s:", c) + echo("return %s[%s(_x_i)]", recvName, keyName) + } + } + + if k.Info()&types.IsString != 0 { + echo("default:") + imports["fmt"] = "fmt" + echo("return %s[%s(fmt.Sprint(i))]", recvName, keyName) + } + + echo(`}`) + } + echo("return nil") + echo(`}`) + case *types.Struct: + echo("func (%s %s) Fetch(i interface{}) interface{} {", recvName, objectName) + + fields := make(map[string]string) + collectStruct(recvName, t, func(c string, r string) { + if _, ok := fields[c]; ok { + fields[c] = "-" + } + fields[c] = r + }) + + keys := make([]string, 0, len(fields)) + for c, r := range fields { + if r == "-" { + continue + } + keys = append(keys, c) + } + sort.Strings(keys) + imports["fmt"] = "fmt" + + echo(`var string_i string`) + echo(`if s, ok := i.(string); ok {`) + echo(`string_i = s`) + echo(`} else {`) + echo(`string_i = fmt.Sprint(i)`) + echo(`}`) + + echo(`switch string_i {`) + for _, key := range keys { + echo("case \"%s\":", key) + echo("return %s", fields[key]) + } + echo(`}`) + echo(`return nil`) + echo(`}`) + } + } + + importsString := "import (\n" + for k, v := range imports { + importsString += k + "\"" + v + "\"\n" + } + importsString += ")" + data = strings.Replace(data, "--imports", importsString, 1) + + return format.Source([]byte(data)) +} + +func collectStruct(recv string, t *types.Struct, collect func(string, string), skippedNames ...string) { + fieldNames := make([]string, 0, t.NumFields()) + for i := 0; i < t.NumFields(); i++ { + fieldNames = append(fieldNames, t.Field(i).Name()) + } + + for i := 0; i < t.NumFields(); i++ { + v := t.Field(i) + if !v.Exported() || contains(skippedNames, v.Name()) { + continue + } + + collect(v.Name(), recv+"."+v.Name()) + + if v.Embedded() { + tt := v.Type() + for dereference(tt) != underlying(tt) { + tt = dereference(tt) + tt = underlying(tt) + } + + switch vt := tt.(type) { + case *types.Struct: + collectStruct(recv+"."+v.Name(), vt, collect, fieldNames...) + } + } + } +} + +func dereference(t types.Type) types.Type { + if p, ok := t.(*types.Pointer); ok { + return dereference(p.Elem()) + } + return t +} + +func underlying(t types.Type) types.Type { + if t != t.Underlying() { + return underlying(t.Underlying()) + } + return t +} + +func contains(arr []string, s string) bool { + for _, e := range arr { + if e == s { + return true + } + } + return false +} diff --git a/go.sum b/go.sum index a43e72bf7..af0ea8e7a 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,7 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/vm/runtime.go b/vm/runtime.go index 926563664..2e0091314 100644 --- a/vm/runtime.go +++ b/vm/runtime.go @@ -15,7 +15,22 @@ type Call struct { type Scope map[string]interface{} +type Fetcher interface { + Fetch(interface{}) interface{} +} + func fetch(from, i interface{}, nilsafe bool) interface{} { + if fetcher, ok := from.(Fetcher); ok { + value := fetcher.Fetch(i) + if value != nil { + return value + } + if !nilsafe { + panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) + } + return nil + } + v := reflect.ValueOf(from) kind := v.Kind()