diff --git a/expr_test.go b/expr_test.go index ced182a9..574f45b2 100644 --- a/expr_test.go +++ b/expr_test.go @@ -624,6 +624,7 @@ func TestExpr(t *testing.T) { date := time.Date(2017, time.October, 23, 18, 30, 0, 0, time.UTC) oneDay, _ := time.ParseDuration("24h") timeNowPlusOneDay := date.Add(oneDay) + mode := mock.ModeEnum(1) env := mock.Env{ Embed: mock.Embed{}, @@ -643,6 +644,7 @@ func TestExpr(t *testing.T) { IntPtr: nil, IntPtrPtr: nil, StringPtr: nil, + ModePtr: &mode, Foo: mock.Foo{ Value: "foo", Bar: mock.Bar{ @@ -1291,6 +1293,11 @@ func TestExpr(t *testing.T) { `1 < 2 < 3 == true`, true, }, + { + // Test pointer dereferencing with custom integer type + `ModePtr == 1`, + true, + }, } for _, tt := range tests { @@ -2736,3 +2743,20 @@ func TestExpr_env_types_map_error(t *testing.T) { _, err = expr.Run(program, envTypes) require.Error(t, err) } + +func Example_pointerDereference() { + type Mode int + mode := Mode(1) + env := map[string]any{ + "Mode": &mode, + } + + output, err := expr.Eval(`Mode == 1`, env) + if err != nil { + fmt.Printf("err: %v", err) + return + } + + fmt.Printf("%v", output) + // Output: true +} diff --git a/test/mock/mock.go b/test/mock/mock.go index fc91652f..acd41805 100644 --- a/test/mock/mock.go +++ b/test/mock/mock.go @@ -9,6 +9,8 @@ import ( "github.com/expr-lang/expr/ast" ) +type ModeEnum int + type Env struct { Embed Ambiguous string @@ -51,6 +53,7 @@ type Env struct { Time time.Time TimePlusDay time.Time Duration time.Duration + ModePtr *ModeEnum } func (p Env) FuncFoo(_ Foo) int { diff --git a/vm/runtime/helpers/main.go b/vm/runtime/helpers/main.go index 54a4fc23..56a0dd72 100644 --- a/vm/runtime/helpers/main.go +++ b/vm/runtime/helpers/main.go @@ -67,9 +67,33 @@ func cases(op string, xs ...[]string) string { echo := func(s string, xs ...any) { out += fmt.Sprintf(s, xs...) + "\n" } + + // Only generate pointer cases for equality operations + if op == "==" { + for _, a := range types { + // Handle *T cases + echo(`case *%v:`, a) + echo(`switch y := b.(type) {`) + echo(`case %v:`, a) + echo(`return *x == y`) + echo(`case *%v:`, a) + echo(`return *x == *y`) + echo(`}`) + } + } + + // Generate regular cases for _, a := range types { echo(`case %v:`, a) echo(`switch y := b.(type) {`) + + // Add pointer case for equality + if op == "==" { + echo(`case *%v:`, a) + echo(`return x == *y`) + } + + // Add regular type cases for _, b := range types { t := "int" if isDuration(a) || isDuration(b) { @@ -172,6 +196,35 @@ func Equal(a, b interface{}) bool { return x == y } } + + // Handle unknown pointer types and custom types + va := reflect.ValueOf(a) + vb := reflect.ValueOf(b) + + // Handle pointers to unknown types + if va.Kind() == reflect.Ptr { + if !va.IsValid() || va.IsNil() { + return vb.Kind() == reflect.Ptr && (!vb.IsValid() || vb.IsNil()) + } + va = va.Elem() + } + if vb.Kind() == reflect.Ptr { + if !vb.IsValid() || vb.IsNil() { + return va.Kind() == reflect.Ptr && (!va.IsValid() || va.IsNil()) + } + vb = vb.Elem() + } + + // Handle custom integer types by converting to int64 + if va.IsValid() && vb.IsValid() { + ka := va.Kind() + kb := vb.Kind() + if (ka == reflect.Int || ka == reflect.Int8 || ka == reflect.Int16 || ka == reflect.Int32 || ka == reflect.Int64) && + (kb == reflect.Int || kb == reflect.Int8 || kb == reflect.Int16 || kb == reflect.Int32 || kb == reflect.Int64) { + return va.Int() == vb.Int() + } + } + if IsNil(a) && IsNil(b) { return true } diff --git a/vm/runtime/helpers[generated].go b/vm/runtime/helpers[generated].go index d950f111..8b7d7621 100644 --- a/vm/runtime/helpers[generated].go +++ b/vm/runtime/helpers[generated].go @@ -10,8 +10,94 @@ import ( func Equal(a, b interface{}) bool { switch x := a.(type) { + case *uint: + switch y := b.(type) { + case uint: + return *x == y + case *uint: + return *x == *y + } + case *uint8: + switch y := b.(type) { + case uint8: + return *x == y + case *uint8: + return *x == *y + } + case *uint16: + switch y := b.(type) { + case uint16: + return *x == y + case *uint16: + return *x == *y + } + case *uint32: + switch y := b.(type) { + case uint32: + return *x == y + case *uint32: + return *x == *y + } + case *uint64: + switch y := b.(type) { + case uint64: + return *x == y + case *uint64: + return *x == *y + } + case *int: + switch y := b.(type) { + case int: + return *x == y + case *int: + return *x == *y + } + case *int8: + switch y := b.(type) { + case int8: + return *x == y + case *int8: + return *x == *y + } + case *int16: + switch y := b.(type) { + case int16: + return *x == y + case *int16: + return *x == *y + } + case *int32: + switch y := b.(type) { + case int32: + return *x == y + case *int32: + return *x == *y + } + case *int64: + switch y := b.(type) { + case int64: + return *x == y + case *int64: + return *x == *y + } + case *float32: + switch y := b.(type) { + case float32: + return *x == y + case *float32: + return *x == *y + } + case *float64: + switch y := b.(type) { + case float64: + return *x == y + case *float64: + return *x == *y + } case uint: switch y := b.(type) { + case *uint: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -39,6 +125,8 @@ func Equal(a, b interface{}) bool { } case uint8: switch y := b.(type) { + case *uint8: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -66,6 +154,8 @@ func Equal(a, b interface{}) bool { } case uint16: switch y := b.(type) { + case *uint16: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -93,6 +183,8 @@ func Equal(a, b interface{}) bool { } case uint32: switch y := b.(type) { + case *uint32: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -120,6 +212,8 @@ func Equal(a, b interface{}) bool { } case uint64: switch y := b.(type) { + case *uint64: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -147,6 +241,8 @@ func Equal(a, b interface{}) bool { } case int: switch y := b.(type) { + case *int: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -174,6 +270,8 @@ func Equal(a, b interface{}) bool { } case int8: switch y := b.(type) { + case *int8: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -201,6 +299,8 @@ func Equal(a, b interface{}) bool { } case int16: switch y := b.(type) { + case *int16: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -228,6 +328,8 @@ func Equal(a, b interface{}) bool { } case int32: switch y := b.(type) { + case *int32: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -255,6 +357,8 @@ func Equal(a, b interface{}) bool { } case int64: switch y := b.(type) { + case *int64: + return x == *y case uint: return int(x) == int(y) case uint8: @@ -282,6 +386,8 @@ func Equal(a, b interface{}) bool { } case float32: switch y := b.(type) { + case *float32: + return x == *y case uint: return float64(x) == float64(y) case uint8: @@ -309,6 +415,8 @@ func Equal(a, b interface{}) bool { } case float64: switch y := b.(type) { + case *float64: + return x == *y case uint: return float64(x) == float64(y) case uint8: @@ -693,6 +801,35 @@ func Equal(a, b interface{}) bool { return x == y } } + + // Handle unknown pointer types and custom types + va := reflect.ValueOf(a) + vb := reflect.ValueOf(b) + + // Handle pointers to unknown types + if va.Kind() == reflect.Ptr { + if !va.IsValid() || va.IsNil() { + return vb.Kind() == reflect.Ptr && (!vb.IsValid() || vb.IsNil()) + } + va = va.Elem() + } + if vb.Kind() == reflect.Ptr { + if !vb.IsValid() || vb.IsNil() { + return va.Kind() == reflect.Ptr && (!va.IsValid() || va.IsNil()) + } + vb = vb.Elem() + } + + // Handle custom integer types by converting to int64 + if va.IsValid() && vb.IsValid() { + ka := va.Kind() + kb := vb.Kind() + if (ka == reflect.Int || ka == reflect.Int8 || ka == reflect.Int16 || ka == reflect.Int32 || ka == reflect.Int64) && + (kb == reflect.Int || kb == reflect.Int8 || kb == reflect.Int16 || kb == reflect.Int32 || kb == reflect.Int64) { + return va.Int() == vb.Int() + } + } + if IsNil(a) && IsNil(b) { return true } diff --git a/vm/runtime/helpers_test.go b/vm/runtime/helpers_test.go index b0ef5340..d2d81b40 100644 --- a/vm/runtime/helpers_test.go +++ b/vm/runtime/helpers_test.go @@ -8,6 +8,8 @@ import ( "github.com/expr-lang/expr/vm/runtime" ) +type CustomInt int + var tests = []struct { name string a, b any @@ -33,6 +35,11 @@ var tests = []struct { {"deep []any != []any", []any{[]int{1}, 2, []any{"3", "42"}}, []any{[]any{1}, 2, []string{"3"}}, false}, {"map[string]any == map[string]any", map[string]any{"a": 1}, map[string]any{"a": 1}, true}, {"map[string]any != map[string]any", map[string]any{"a": 1}, map[string]any{"a": 1, "b": 2}, false}, + {name: "*CustomInt == int", a: func() any { x := CustomInt(1); return &x }(), b: 1, want: true}, + {name: "int == *CustomInt", a: 1, b: func() any { x := CustomInt(1); return &x }(), want: true}, + {name: "*CustomInt != int", a: func() any { x := CustomInt(2); return &x }(), b: 1, want: false}, + {name: "*CustomInt == *CustomInt", a: func() any { x := CustomInt(1); return &x }(), b: func() any { x := CustomInt(1); return &x }(), want: true}, + {name: "*CustomInt != *CustomInt", a: func() any { x := CustomInt(1); return &x }(), b: func() any { x := CustomInt(2); return &x }(), want: false}, } func TestEqual(t *testing.T) { @@ -44,7 +51,6 @@ func TestEqual(t *testing.T) { assert.Equal(t, tt.want, got, "Equal(%v, %v) = %v; want %v", tt.b, tt.a, got, tt.want) }) } - } func BenchmarkEqual(b *testing.B) {