From ffb64490b554ed243b9b205a8c7e13887fe14008 Mon Sep 17 00:00:00 2001 From: Song Gao Date: Thu, 19 Dec 2024 16:28:07 +0800 Subject: [PATCH] add inc_avg Signed-off-by: Song Gao --- internal/binder/function/binder.go | 9 -- internal/binder/function/funcs_agg.go | 23 ---- internal/binder/function/funcs_inc_agg.go | 101 ++++++++++++++++++ .../binder/function/funcs_inc_agg_test.go | 66 ++++++++++++ 4 files changed, 167 insertions(+), 32 deletions(-) create mode 100644 internal/binder/function/funcs_inc_agg.go create mode 100644 internal/binder/function/funcs_inc_agg_test.go diff --git a/internal/binder/function/binder.go b/internal/binder/function/binder.go index 0c6d9e852f..6607bc83a1 100644 --- a/internal/binder/function/binder.go +++ b/internal/binder/function/binder.go @@ -135,12 +135,3 @@ func GetFuncType(funcName string) ast.FuncType { } return ast.FuncTypeUnknown } - -var supportedIncAggFunc = map[string]struct{}{ - "count": {}, -} - -func IsSupportedIncAgg(name string) bool { - _, ok := supportedIncAggFunc[name] - return ok -} diff --git a/internal/binder/function/funcs_agg.go b/internal/binder/function/funcs_agg.go index 74538f514d..c3494a0a2b 100644 --- a/internal/binder/function/funcs_agg.go +++ b/internal/binder/function/funcs_agg.go @@ -379,26 +379,3 @@ func registerAggFunc() { check: returnNilIfHasAnyNil, } } - -func registerIncAggFunc() { - builtins["inc_count"] = builtinFunc{ - fType: ast.FuncTypeScalar, - exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) { - key := fmt.Sprintf("%v_inc_count", ctx.GetFuncId()) - v, err := ctx.GetState(key) - if err != nil { - return err, false - } - var c int64 - if v == nil { - c = 1 - } else { - c = v.(int64) + 1 - } - ctx.PutState(key, c) - return c, true - }, - val: ValidateOneArg, - check: returnNilIfHasAnyNil, - } -} diff --git a/internal/binder/function/funcs_inc_agg.go b/internal/binder/function/funcs_inc_agg.go new file mode 100644 index 0000000000..53282c0d90 --- /dev/null +++ b/internal/binder/function/funcs_inc_agg.go @@ -0,0 +1,101 @@ +// Copyright 2024 EMQ Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "fmt" + + "github.com/lf-edge/ekuiper/contract/v2/api" + + "github.com/lf-edge/ekuiper/v2/pkg/ast" + "github.com/lf-edge/ekuiper/v2/pkg/cast" +) + +var supportedIncAggFunc = map[string]struct{}{ + "count": {}, + "avg": {}, +} + +func IsSupportedIncAgg(name string) bool { + _, ok := supportedIncAggFunc[name] + return ok +} + +func registerIncAggFunc() { + builtins["inc_count"] = builtinFunc{ + fType: ast.FuncTypeScalar, + exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) { + c, err := incrementalCount(ctx, args[0]) + if err != nil { + return err, false + } + return c, true + }, + val: ValidateOneArg, + check: returnNilIfHasAnyNil, + } + builtins["inc_avg"] = builtinFunc{ + fType: ast.FuncTypeScalar, + exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) { + arg0, err := cast.ToFloat64(args[0], cast.CONVERT_ALL) + if err != nil { + return err, false + } + count, err := incrementalCount(ctx, arg0) + if err != nil { + return err, false + } + sum, err := incrementalSum(ctx, arg0) + if err != nil { + return err, false + } + return sum / float64(count), true + }, + val: ValidateOneNumberArg, + check: returnNilIfHasAnyNil, + } +} + +func incrementalCount(ctx api.FunctionContext, arg interface{}) (int64, error) { + key := fmt.Sprintf("%v_inc_count", ctx.GetFuncId()) + v, err := ctx.GetState(key) + if err != nil { + return 0, err + } + var c int64 + if v == nil { + c = 1 + } else { + c = v.(int64) + 1 + } + ctx.PutState(key, c) + return c, nil +} + +func incrementalSum(ctx api.FunctionContext, arg float64) (float64, error) { + key := fmt.Sprintf("%v_inc_sum", ctx.GetFuncId()) + v, err := ctx.GetState(key) + if err != nil { + return 0, err + } + var sum float64 + if v == nil { + sum = arg + } else { + sum = v.(float64) + arg + } + ctx.PutState(key, sum) + return sum, nil +} diff --git a/internal/binder/function/funcs_inc_agg_test.go b/internal/binder/function/funcs_inc_agg_test.go new file mode 100644 index 0000000000..034e4460e9 --- /dev/null +++ b/internal/binder/function/funcs_inc_agg_test.go @@ -0,0 +1,66 @@ +// Copyright 2024 EMQ Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/lf-edge/ekuiper/v2/internal/conf" + "github.com/lf-edge/ekuiper/v2/internal/pkg/def" + kctx "github.com/lf-edge/ekuiper/v2/internal/topo/context" + "github.com/lf-edge/ekuiper/v2/internal/topo/state" +) + +func TestIncAggFunction(t *testing.T) { + contextLogger := conf.Log.WithField("rule", "testExec") + registerIncAggFunc() + testcases := []struct { + funcName string + args1 []interface{} + output1 interface{} + args2 []interface{} + output2 interface{} + }{ + { + funcName: "inc_count", + args1: []interface{}{1}, + output1: int64(1), + args2: []interface{}{1}, + output2: int64(2), + }, + { + funcName: "inc_avg", + args1: []interface{}{1}, + output1: float64(1), + args2: []interface{}{3}, + output2: float64(2), + }, + } + for index, tc := range testcases { + ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger) + tempStore, _ := state.CreateStore(tc.funcName, def.AtMostOnce) + fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), index) + f, ok := builtins[tc.funcName] + require.True(t, ok) + got1, ok := f.exec(fctx, tc.args1) + require.True(t, ok) + require.Equal(t, tc.output1, got1) + got2, ok := f.exec(fctx, tc.args2) + require.True(t, ok) + require.Equal(t, tc.output2, got2) + } +}