From ded019d21ec180b7018641732617c59e8932da32 Mon Sep 17 00:00:00 2001 From: Ganesan Karuppasamy Date: Sun, 3 Mar 2024 21:14:32 +0530 Subject: [PATCH] Enable Support for Arrays in Sum, Mean, and Median Functions (#580) --- builtin/builtin.go | 201 ++++++---------------------------------- builtin/builtin_test.go | 13 +++ builtin/lib.go | 154 ++++++++++++++++++++++++------ builtin/validation.go | 38 ++++++++ 4 files changed, 206 insertions(+), 200 deletions(-) create mode 100644 builtin/validation.go diff --git a/builtin/builtin.go b/builtin/builtin.go index fc48e111..7bf377df 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -135,42 +135,21 @@ var Builtins = []*Function{ Name: "ceil", Fast: Ceil, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for ceil (type %s)", args[0]) + return validateRoundFunc("ceil", args) }, }, { Name: "floor", Fast: Floor, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0]) + return validateRoundFunc("floor", args) }, }, { Name: "round", Fast: Round, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0]) + return validateRoundFunc("round", args) }, }, { @@ -392,185 +371,63 @@ var Builtins = []*Function{ }, { Name: "max", - Func: Max, + Func: func(args ...any) (any, error) { + return minMax("max", runtime.Less, args...) + }, Validate: func(args []reflect.Type) (reflect.Type, error) { - switch len(args) { - case 0: - return anyType, fmt.Errorf("not enough arguments to call max") - case 1: - if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice { - return anyType, nil - } - fallthrough - default: - for _, arg := range args { - switch kind(arg) { - case reflect.Interface, reflect.Array, reflect.Slice: - return anyType, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - default: - return anyType, fmt.Errorf("invalid argument for max (type %s)", arg) - } - } - return args[0], nil - } + return validateAggregateFunc("max", args) }, }, { Name: "min", - Func: Min, + Func: func(args ...any) (any, error) { + return minMax("min", runtime.More, args...) + }, Validate: func(args []reflect.Type) (reflect.Type, error) { - switch len(args) { - case 0: - return anyType, fmt.Errorf("not enough arguments to call min") - case 1: - if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice { - return anyType, nil - } - fallthrough - default: - for _, arg := range args { - switch kind(arg) { - case reflect.Interface, reflect.Array, reflect.Slice: - return anyType, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - default: - return anyType, fmt.Errorf("invalid argument for min (type %s)", arg) - } - } - return args[0], nil - - } + return validateAggregateFunc("min", args) }, }, { Name: "sum", - Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot sum %s", v.Kind()) - } - sum := int64(0) - i := 0 - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - sum += it.Int() - } else if it.CanFloat() { - goto float - } else { - return nil, fmt.Errorf("cannot sum %s", it.Kind()) - } - } - return int(sum), nil - float: - fSum := float64(sum) - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - fSum += float64(it.Int()) - } else if it.CanFloat() { - fSum += it.Float() - } else { - return nil, fmt.Errorf("cannot sum %s", it.Kind()) - } - } - return fSum, nil - }, + Func: sum, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot sum %s", args[0]) - } - return anyType, nil + return validateAggregateFunc("sum", args) }, }, { Name: "mean", Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot mean %s", v.Kind()) + count, sum, err := mean(args...) + if err != nil { + return nil, err } - if v.Len() == 0 { + if count == 0 { return 0.0, nil } - sum := float64(0) - i := 0 - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - sum += float64(it.Int()) - } else if it.CanFloat() { - sum += it.Float() - } else { - return nil, fmt.Errorf("cannot mean %s", it.Kind()) - } - } - return sum / float64(i), nil + return sum / float64(count), nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot avg %s", args[0]) - } - return floatType, nil + return validateAggregateFunc("mean", args) }, }, { Name: "median", Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot median %s", v.Kind()) - } - if v.Len() == 0 { - return 0.0, nil + values, err := median(args...) + if err != nil { + return nil, err } - s := make([]float64, v.Len()) - for i := 0; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - s[i] = float64(it.Int()) - } else if it.CanFloat() { - s[i] = it.Float() - } else { - return nil, fmt.Errorf("cannot median %s", it.Kind()) + if n := len(values); n > 0 { + sort.Float64s(values) + if n%2 == 1 { + return values[n/2], nil } + return (values[n/2-1] + values[n/2]) / 2, nil } - sort.Float64s(s) - if len(s)%2 == 0 { - return (s[len(s)/2-1] + s[len(s)/2]) / 2, nil - } - return s[len(s)/2], nil + return 0.0, nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot median %s", args[0]) - } - return floatType, nil + return validateAggregateFunc("median", args) }, }, { diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index bc1a2e14..aa324c9b 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -85,19 +85,29 @@ func TestBuiltin(t *testing.T) { {`min(1.5, 2.5, 3.5)`, 1.5}, {`min([1, 2, 3])`, 1}, {`min([1.5, 2.5, 3.5])`, 1.5}, + {`min(-1, [1.5, 2.5, 3.5])`, -1}, {`sum(1..9)`, 45}, {`sum([.5, 1.5, 2.5])`, 4.5}, {`sum([])`, 0}, {`sum([1, 2, 3.0, 4])`, 10.0}, + {`sum(10, [1, 2, 3], 1..9)`, 61}, + {`sum(-10, [1, 2, 3, 4])`, 0}, + {`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1}, {`mean(1..9)`, 5.0}, {`mean([.5, 1.5, 2.5])`, 1.5}, {`mean([])`, 0.0}, {`mean([1, 2, 3.0, 4])`, 2.5}, + {`mean(10, [1, 2, 3], 1..9)`, 4.6923076923076925}, + {`mean(-10, [1, 2, 3, 4])`, 0.0}, + {`mean(10.9, 1..9)`, 5.59}, {`median(1..9)`, 5.0}, {`median([.5, 1.5, 2.5])`, 1.5}, {`median([])`, 0.0}, {`median([1, 2, 3])`, 2.0}, {`median([1, 2, 3, 4])`, 2.5}, + {`median(10, [1, 2, 3], 1..9)`, 4.0}, + {`median(-10, [1, 2, 3, 4])`, 2.0}, + {`median(1..5, 4.9)`, 3.5}, {`toJSON({foo: 1, bar: 2})`, "{\n \"bar\": 2,\n \"foo\": 1\n}"}, {`fromJSON("[1, 2, 3]")`, []any{1.0, 2.0, 3.0}}, {`toBase64("hello")`, "aGVsbG8="}, @@ -207,6 +217,9 @@ func TestBuiltin_errors(t *testing.T) { {`min()`, `not enough arguments to call min`}, {`min(1, "2")`, `invalid argument for min (type string)`}, {`min([1, "2"])`, `invalid argument for min (type string)`}, + {`median(1..9, "t")`, "invalid argument for median (type string)"}, + {`mean("s", 1..9)`, "invalid argument for mean (type string)"}, + {`sum("s", "h")`, "invalid argument for sum (type string)"}, {`duration("error")`, `invalid duration`}, {`date("error")`, `invalid date`}, {`get()`, `invalid number of arguments (expected 2, got 0)`}, diff --git a/builtin/lib.go b/builtin/lib.go index b08c2ed2..9ff9478a 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -6,7 +6,7 @@ import ( "reflect" "strconv" - "github.com/expr-lang/expr/vm/runtime" + "github.com/expr-lang/expr/internal/deref" ) func Len(x any) any { @@ -254,45 +254,143 @@ func String(arg any) any { return fmt.Sprintf("%v", arg) } -func Max(args ...any) (any, error) { - return minMaxFunc("max", runtime.Less, args) -} +func sum(args ...any) (any, error) { + var total int + var fTotal float64 + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) -func Min(args ...any) (any, error) { - return minMaxFunc("min", runtime.More, args) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemSum, err := sum(rv.Index(i).Interface()) + if err != nil { + return nil, err + } + switch elemSum := elemSum.(type) { + case int: + total += elemSum + case float64: + fTotal += elemSum + } + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + total += int(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + total += int(rv.Uint()) + case reflect.Float32, reflect.Float64: + fTotal += rv.Float() + default: + return nil, fmt.Errorf("invalid argument for sum (type %T)", arg) + } + } + + if fTotal != 0.0 { + return fTotal + float64(total), nil + } + return total, nil } -func minMaxFunc(name string, fn func(any, any) bool, args []any) (any, error) { +func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { var val any for _, arg := range args { - switch v := arg.(type) { - case []float32, []float64, []uint, []uint8, []uint16, []uint32, []uint64, []int, []int8, []int16, []int32, []int64: - rv := reflect.ValueOf(v) - if rv.Len() == 0 { - return nil, fmt.Errorf("not enough arguments to call %s", name) - } - arg = rv.Index(0).Interface() - for i := 1; i < rv.Len(); i++ { - elem := rv.Index(i).Interface() - if fn(arg, elem) { - arg = elem + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemVal, err := minMax(name, fn, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + switch elemVal.(type) { + case int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64: + if elemVal != nil && (val == nil || fn(val, elemVal)) { + val = elemVal + } + default: + return nil, fmt.Errorf("invalid argument for %s (type %T)", name, elemVal) } + } - case []any: - var err error - if arg, err = minMaxFunc(name, fn, v); err != nil { - return nil, err + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + elemVal := rv.Interface() + if val == nil || fn(val, elemVal) { + val = elemVal } - case float32, float64, uint, uint8, uint16, uint32, uint64, int, int8, int16, int32, int64: default: if len(args) == 1 { - return arg, nil + return args[0], nil } - return nil, fmt.Errorf("invalid argument for %s (type %T)", name, v) - } - if val == nil || fn(val, arg) { - val = arg + return nil, fmt.Errorf("invalid argument for %s (type %T)", name, arg) } } return val, nil } + +func mean(args ...any) (int, float64, error) { + var total float64 + var count int + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemCount, elemSum, err := mean(rv.Index(i).Interface()) + if err != nil { + return 0, 0, err + } + total += elemSum + count += elemCount + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + total += float64(rv.Int()) + count++ + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + total += float64(rv.Uint()) + count++ + case reflect.Float32, reflect.Float64: + total += rv.Float() + count++ + default: + return 0, 0, fmt.Errorf("invalid argument for mean (type %T)", arg) + } + } + return count, total, nil +} + +func median(args ...any) ([]float64, error) { + var values []float64 + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elems, err := median(rv.Index(i).Interface()) + if err != nil { + return nil, err + } + values = append(values, elems...) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + values = append(values, float64(rv.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + values = append(values, float64(rv.Uint())) + case reflect.Float32, reflect.Float64: + values = append(values, rv.Float()) + default: + return nil, fmt.Errorf("invalid argument for median (type %T)", arg) + } + } + return values, nil +} diff --git a/builtin/validation.go b/builtin/validation.go new file mode 100644 index 00000000..057f247e --- /dev/null +++ b/builtin/validation.go @@ -0,0 +1,38 @@ +package builtin + +import ( + "fmt" + "reflect" + + "github.com/expr-lang/expr/internal/deref" +) + +func validateAggregateFunc(name string, args []reflect.Type) (reflect.Type, error) { + switch len(args) { + case 0: + return anyType, fmt.Errorf("not enough arguments to call %s", name) + default: + for _, arg := range args { + switch kind(deref.Type(arg)) { + case reflect.Interface, reflect.Array, reflect.Slice: + return anyType, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: + default: + return anyType, fmt.Errorf("invalid argument for %s (type %s)", name, arg) + } + } + return args[0], nil + } +} + +func validateRoundFunc(name string, args []reflect.Type) (reflect.Type, error) { + if len(args) != 1 { + return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) + } + switch kind(args[0]) { + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: + return floatType, nil + default: + return anyType, fmt.Errorf("invalid argument for %s (type %s)", name, args[0]) + } +}