diff --git a/dbfixture/fixture.go b/dbfixture/fixture.go index ad77bf225..3d963d726 100644 --- a/dbfixture/fixture.go +++ b/dbfixture/fixture.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" "text/template" + "text/template/parse" "time" "gopkg.in/yaml.v3" @@ -249,13 +250,11 @@ func (f *Fixture) decodeField(strct reflect.Value, field *schema.Field, value *y } if tplRE.MatchString(value.Value) { - str, err := f.eval(value.Value) + src, err := f.eval(value.Value) if err != nil { return err } - if str != value.Value { - return field.ScanValue(strct, str) - } + return field.ScanValue(strct, src) } if v, ok := iface.(yaml.Unmarshaler); ok { @@ -310,21 +309,100 @@ func (f *Fixture) truncateTable(ctx context.Context, table *schema.Table) error return nil } -func (f *Fixture) eval(templ string) (string, error) { +func (f *Fixture) eval(templ string) (interface{}, error) { + if v, ok := f.evalFuncCall(templ); ok { + return v, nil + } + tpl, err := template.New("").Funcs(f.funcMap).Parse(templ) if err != nil { - return "", err + return nil, err } var buf bytes.Buffer if err := tpl.Execute(&buf, f.modelRows); err != nil { - return "", err + return nil, err } return buf.String(), nil } +func (f *Fixture) evalFuncCall(templ string) (interface{}, bool) { + tree, err := parse.Parse("", templ, "{{", "}}", f.funcMap) + if err != nil { + return nil, false + } + + root := tree[""].Root + if len(root.Nodes) != 1 { + return nil, false + } + + action, ok := root.Nodes[0].(*parse.ActionNode) + if !ok { + return nil, false + } + + if len(action.Pipe.Cmds) != 1 { + return nil, false + } + + args := action.Pipe.Cmds[0].Args + if len(args) == 0 { + return nil, false + } + + funcName, ok := args[0].(*parse.IdentifierNode) + if !ok { + return nil, false + } + + fn, ok := f.funcMap[funcName.Ident] + if !ok { + return nil, false + } + + fnValue := reflect.ValueOf(fn) + fnType := fnValue.Type() + if fnType.NumOut() != 1 { + return nil, false + } + + args = args[1:] + if len(args) != fnType.NumIn() { + return nil, false + } + argValues := make([]reflect.Value, len(args)) + + for i, node := range args { + switch node := node.(type) { + case *parse.StringNode: + argValues[i] = reflect.ValueOf(node.Text) + case *parse.NumberNode: + switch { + case node.IsInt: + argValues[i] = reflect.ValueOf(node.Int64) + case node.IsUint: + argValues[i] = reflect.ValueOf(node.Uint64) + case node.IsFloat: + argValues[i] = reflect.ValueOf(node.Float64) + case node.IsComplex: + argValues[i] = reflect.ValueOf(node.Complex128) + default: + argValues[i] = reflect.ValueOf(node.Text) + } + case *parse.BoolNode: + argValues[i] = reflect.ValueOf(node.True) + default: + return nil, false + } + } + + out := fnValue.Call(argValues) + return out[0].Interface(), true +} + type fixtureData struct { Model string `yaml:"model"` Rows []row `yaml:"rows"`