diff --git a/eval_test.go b/eval_test.go index 9a8dfd754..5c252f3d3 100644 --- a/eval_test.go +++ b/eval_test.go @@ -586,19 +586,19 @@ func TestEval_panic(t *testing.T) { } func TestEval_method(t *testing.T) { - env := testEnv{ + env := &testEnv{ Hello: "hello", World: testWorld{ name: []string{"w", "o", "r", "l", "d"}, }, - testVersion: testVersion{ + testVersion: &testVersion{ version: 2, }, } - input := `Title(Hello) ~ ' ' ~ (CompareVersion(1, 3) ? World.String() : '')` + input := `Title(Hello) ~ Empty() ~ (CompareVersion(1, 3) ? World.String() : '')` - node, err := expr.Parse(input) + node, err := expr.Parse(input, expr.Env(&testEnv{})) fmt.Printf("%#v\n", node) if err != nil { t.Fatal(err) @@ -619,7 +619,7 @@ type testVersion struct { version float64 } -func (c testVersion) CompareVersion(min, max float64) bool { +func (c *testVersion) CompareVersion(min, max float64) bool { return min < c.version && c.version < max } @@ -632,11 +632,15 @@ func (w testWorld) String() string { } type testEnv struct { - testVersion + *testVersion Hello string World testWorld } -func (e testEnv) Title(s string) string { +func (e *testEnv) Title(s string) string { return strings.Title(s) } + +func (e *testEnv) Empty() string { + return " " +} diff --git a/parser.go b/parser.go index 2d0cd0d9f..aa24956c7 100644 --- a/parser.go +++ b/parser.go @@ -140,14 +140,22 @@ func (p *parser) createTypesTable(i interface{}) typesTable { v := reflect.ValueOf(i) t := reflect.TypeOf(i) - t = dereference(t) - if t == nil { - return types + d := t + if t.Kind() == reflect.Ptr { + d = t.Elem() } - switch t.Kind() { + switch d.Kind() { case reflect.Struct: - types = p.fromStruct(t) + types = p.fieldsFromStruct(d) + + // Methods of struct should be gathered from original struct with pointer, + // as methods maybe declared on pointer receiver. Also this method retrieves + // all embedded structs methods as well, no need to recursion. + for i := 0; i < t.NumMethod(); i++ { + m := t.Method(i) + types[m.Name] = m.Type + } case reflect.Map: for _, key := range v.MapKeys() { @@ -161,7 +169,7 @@ func (p *parser) createTypesTable(i interface{}) typesTable { return types } -func (p *parser) fromStruct(t reflect.Type) typesTable { +func (p *parser) fieldsFromStruct(t reflect.Type) typesTable { types := make(typesTable) t = dereference(t) if t == nil { @@ -174,18 +182,13 @@ func (p *parser) fromStruct(t reflect.Type) typesTable { f := t.Field(i) if f.Anonymous { - for name, typ := range p.fromStruct(f.Type) { + for name, typ := range p.fieldsFromStruct(f.Type) { types[name] = typ } } types[f.Name] = f.Type } - - for i := 0; i < t.NumMethod(); i++ { - m := t.Method(i) - types[m.Name] = m.Type - } } return types diff --git a/runtime.go b/runtime.go index fc53ffc88..9ee1bf1d6 100644 --- a/runtime.go +++ b/runtime.go @@ -103,9 +103,14 @@ func extract(val interface{}, i interface{}) (interface{}, bool) { func getFunc(val interface{}, i interface{}) (interface{}, bool) { v := reflect.ValueOf(val) - switch v.Kind() { + d := v + if v.Kind() == reflect.Ptr { + d = v.Elem() + } + + switch d.Kind() { case reflect.Map: - value := v.MapIndex(reflect.ValueOf(i)) + value := d.MapIndex(reflect.ValueOf(i)) if value.IsValid() && value.CanInterface() { return value.Interface(), true } @@ -119,11 +124,6 @@ func getFunc(val interface{}, i interface{}) (interface{}, bool) { if value.IsValid() && value.CanInterface() { return value.Interface(), true } - case reflect.Ptr: - value := v.Elem() - if value.IsValid() && value.CanInterface() { - return getFunc(value.Interface(), i) - } } return nil, false }