From ea7eb5175fe38e26f77d7cfaec5dc0cdae9a2546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Tue, 13 Jun 2023 08:52:40 +0200 Subject: [PATCH] tpl/math: Allow variadic math functions to take slice args, add math.Product, math.Sum * Update math.Min and math.Max to allow 1 or more args, either scalar or slice, or combination of the two * Add math.Sum (allow 1 or more args, either scalar or slice, or combination of the two) * Add math.Product (allow 1 or more args, either scalar or slice, or combination of the two) Fixes #11030 --- tpl/math/math.go | 105 +++++++++++++++++++++++++++--------------- tpl/math/math_test.go | 54 +++++++++++++++++++++- 2 files changed, 121 insertions(+), 38 deletions(-) diff --git a/tpl/math/math.go b/tpl/math/math.go index d73f212a658..8739e76eace 100644 --- a/tpl/math/math.go +++ b/tpl/math/math.go @@ -16,7 +16,9 @@ package math import ( "errors" + "fmt" "math" + "reflect" "sync/atomic" _math "github.com/gohugoio/hugo/common/math" @@ -85,48 +87,30 @@ func (ns *Namespace) Log(n any) (float64, error) { return math.Log(af), nil } -// Max returns the greater of the multivalued numbers n1 and n2 or more values. +// Max returns the greater of all numbers in inputs. Any slices in inputs are flattened. func (ns *Namespace) Max(inputs ...any) (maximum float64, err error) { - if len(inputs) < 2 { - err = errMustTwoNumbersError - return - } - var value float64 - for index, input := range inputs { - value, err = cast.ToFloat64E(input) - if err != nil { - err = errors.New("Max operator can't be used with non-float value") - return - } - if index == 0 { - maximum = value - continue - } - maximum = math.Max(value, maximum) - } - return + return ns.applyOpToScalarsOrSlices("Max", math.Max, inputs...) } -// Min returns the smaller of multivalued numbers n1 and n2 or more values. +// Min returns the smaller of all numbers in inputs. Any slices in inputs are flattened. func (ns *Namespace) Min(inputs ...any) (minimum float64, err error) { - if len(inputs) < 2 { - err = errMustTwoNumbersError - return + return ns.applyOpToScalarsOrSlices("Min", math.Min, inputs...) +} + +// Sum returns the sum of all numbers in inputs. Any slices in inputs are flattened. +func (ns *Namespace) Sum(inputs ...any) (sum float64, err error) { + fn := func(x, y float64) float64 { + return x + y } - var value float64 - for index, input := range inputs { - value, err = cast.ToFloat64E(input) - if err != nil { - err = errors.New("Max operator can't be used with non-float value") - return - } - if index == 0 { - minimum = value - continue - } - minimum = math.Min(value, minimum) + return ns.applyOpToScalarsOrSlices("Sum", fn, inputs...) +} + +// Product returns the product of all numbers in inputs. Any slices in inputs are flattened. +func (ns *Namespace) Product(inputs ...any) (product float64, err error) { + fn := func(x, y float64) float64 { + return x * y } - return + return ns.applyOpToScalarsOrSlices("Product", fn, inputs...) } // Mod returns n1 % n2. @@ -197,6 +181,55 @@ func (ns *Namespace) Sub(inputs ...any) (any, error) { return ns.doArithmetic(inputs, '-') } +func (ns *Namespace) applyOpToScalarsOrSlices(opName string, op func(x, y float64) float64, inputs ...any) (result float64, err error) { + var i int + for _, input := range inputs { + var values []float64 + values, err = ns.toFloatsE(input) + if err != nil { + err = fmt.Errorf("%s operator can't be used with non-float values", opName) + return + } + for _, value := range values { + i++ + if i == 1 { + result = value + continue + } + result = op(result, value) + } + } + + if i < 2 { + err = errMustTwoNumbersError + return + } + return + +} + +func (ns *Namespace) toFloatsE(v any) ([]float64, error) { + vv := reflect.ValueOf(v) + switch vv.Kind() { + case reflect.Slice, reflect.Array: + var floats []float64 + for i := 0; i < vv.Len(); i++ { + f, err := cast.ToFloat64E(vv.Index(i).Interface()) + if err != nil { + return nil, err + } + floats = append(floats, f) + } + return floats, nil + default: + f, err := cast.ToFloat64E(v) + if err != nil { + return nil, err + } + return []float64{f}, nil + } +} + func (ns *Namespace) doArithmetic(inputs []any, operation rune) (value any, err error) { if len(inputs) < 2 { return nil, errMustTwoNumbersError diff --git a/tpl/math/math_test.go b/tpl/math/math_test.go index fad86938df0..8d299f794e7 100644 --- a/tpl/math/math_test.go +++ b/tpl/math/math_test.go @@ -421,6 +421,11 @@ func TestMax(t *testing.T) { {[]any{0, "a"}, false}, {[]any{"a", 0}, false}, {[]any{"a", "b"}, false}, + // Issue #11030 + {[]any{7, []any{3, 4}}, 7.0}, + {[]any{8, []any{3, 12}, 3}, 12.0}, + {[]any{[]any{3, 5, 2}}, 5.0}, + {[]any{3, []int{3, 6}, 3}, 6.0}, // miss values {[]any{}, false}, {[]any{0}, false}, @@ -437,7 +442,7 @@ func TestMax(t *testing.T) { } c.Assert(err, qt.IsNil) - c.Assert(result, qt.Equals, test.expect) + c.Assert(result, qt.Equals, test.expect, qt.Commentf("values: %v", test.values)) } } @@ -468,6 +473,11 @@ func TestMin(t *testing.T) { {[]any{0, "a"}, false}, {[]any{"a", 0}, false}, {[]any{"a", "b"}, false}, + // Issue #11030 + {[]any{1, []any{3, 4}}, 1.0}, + {[]any{8, []any{3, 2}, 3}, 2.0}, + {[]any{[]any{3, 2, 2}}, 2.0}, + {[]any{8, []int{3, 2}, 3}, 2.0}, // miss values {[]any{}, false}, {[]any{0}, false}, @@ -484,7 +494,47 @@ func TestMin(t *testing.T) { continue } - c.Assert(err, qt.IsNil) + c.Assert(err, qt.IsNil, qt.Commentf("values: %v", test.values)) c.Assert(result, qt.Equals, test.expect) } } + +func TestSum(t *testing.T) { + t.Parallel() + c := qt.New(t) + + ns := New() + + mustSum := func(values ...any) any { + result, err := ns.Sum(values...) + c.Assert(err, qt.IsNil) + return result + } + + c.Assert(mustSum(1, 2, 3), qt.Equals, 6.0) + c.Assert(mustSum(1, 2, 3.0), qt.Equals, 6.0) + c.Assert(mustSum(1, 2, []any{3, 4}), qt.Equals, 10.0) + + _, err := ns.Sum(1) + c.Assert(err, qt.Not(qt.IsNil)) +} + +func TestProduct(t *testing.T) { + t.Parallel() + c := qt.New(t) + + ns := New() + + mustProduct := func(values ...any) any { + result, err := ns.Product(values...) + c.Assert(err, qt.IsNil) + return result + } + + c.Assert(mustProduct(2, 2, 3), qt.Equals, 12.0) + c.Assert(mustProduct(1, 2, 3.0), qt.Equals, 6.0) + c.Assert(mustProduct(1, 2, []any{3, 4}), qt.Equals, 24.0) + + _, err := ns.Product(1) + c.Assert(err, qt.Not(qt.IsNil)) +}