Skip to content

Commit

Permalink
feat: support incremental avg (#3457)
Browse files Browse the repository at this point in the history
Signed-off-by: Song Gao <disxiaofei@163.com>
  • Loading branch information
Yisaer authored Dec 23, 2024
1 parent c361fdd commit e4d7792
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 32 deletions.
9 changes: 0 additions & 9 deletions internal/binder/function/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,3 @@ func GetFuncType(funcName string) ast.FuncType {
}
return ast.FuncTypeUnknown
}

var supportedIncAggFunc = map[string]struct{}{
"count": {},
}

func IsSupportedIncAgg(name string) bool {
_, ok := supportedIncAggFunc[name]
return ok
}
23 changes: 0 additions & 23 deletions internal/binder/function/funcs_agg.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,26 +379,3 @@ func registerAggFunc() {
check: returnNilIfHasAnyNil,
}
}

func registerIncAggFunc() {
builtins["inc_count"] = builtinFunc{
fType: ast.FuncTypeScalar,
exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
key := fmt.Sprintf("%v_inc_count", ctx.GetFuncId())
v, err := ctx.GetState(key)
if err != nil {
return err, false
}
var c int64
if v == nil {
c = 1
} else {
c = v.(int64) + 1
}
ctx.PutState(key, c)
return c, true
},
val: ValidateOneArg,
check: returnNilIfHasAnyNil,
}
}
101 changes: 101 additions & 0 deletions internal/binder/function/funcs_inc_agg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright 2024 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package function

import (
"fmt"

"github.com/lf-edge/ekuiper/contract/v2/api"

"github.com/lf-edge/ekuiper/v2/pkg/ast"
"github.com/lf-edge/ekuiper/v2/pkg/cast"
)

var supportedIncAggFunc = map[string]struct{}{
"count": {},
"avg": {},
}

func IsSupportedIncAgg(name string) bool {
_, ok := supportedIncAggFunc[name]
return ok
}

func registerIncAggFunc() {
builtins["inc_count"] = builtinFunc{
fType: ast.FuncTypeScalar,
exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
c, err := incrementalCount(ctx, args[0])
if err != nil {
return err, false
}
return c, true
},
val: ValidateOneArg,
check: returnNilIfHasAnyNil,
}
builtins["inc_avg"] = builtinFunc{
fType: ast.FuncTypeScalar,
exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
arg0, err := cast.ToFloat64(args[0], cast.CONVERT_ALL)
if err != nil {
return err, false
}
count, err := incrementalCount(ctx, arg0)
if err != nil {
return err, false
}
sum, err := incrementalSum(ctx, arg0)
if err != nil {
return err, false
}
return sum / float64(count), true
},
val: ValidateOneNumberArg,
check: returnNilIfHasAnyNil,
}
}

func incrementalCount(ctx api.FunctionContext, arg interface{}) (int64, error) {
key := fmt.Sprintf("%v_inc_count", ctx.GetFuncId())
v, err := ctx.GetState(key)
if err != nil {
return 0, err
}
var c int64
if v == nil {
c = 1
} else {
c = v.(int64) + 1
}
ctx.PutState(key, c)
return c, nil
}

func incrementalSum(ctx api.FunctionContext, arg float64) (float64, error) {
key := fmt.Sprintf("%v_inc_sum", ctx.GetFuncId())
v, err := ctx.GetState(key)
if err != nil {
return 0, err
}
var sum float64
if v == nil {
sum = arg
} else {
sum = v.(float64) + arg
}
ctx.PutState(key, sum)
return sum, nil
}
66 changes: 66 additions & 0 deletions internal/binder/function/funcs_inc_agg_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2024 EMQ Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package function

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/lf-edge/ekuiper/v2/internal/conf"
"github.com/lf-edge/ekuiper/v2/internal/pkg/def"
kctx "github.com/lf-edge/ekuiper/v2/internal/topo/context"
"github.com/lf-edge/ekuiper/v2/internal/topo/state"
)

func TestIncAggFunction(t *testing.T) {
contextLogger := conf.Log.WithField("rule", "testExec")
registerIncAggFunc()
testcases := []struct {
funcName string
args1 []interface{}
output1 interface{}
args2 []interface{}
output2 interface{}
}{
{
funcName: "inc_count",
args1: []interface{}{1},
output1: int64(1),
args2: []interface{}{1},
output2: int64(2),
},
{
funcName: "inc_avg",
args1: []interface{}{1},
output1: float64(1),
args2: []interface{}{3},
output2: float64(2),
},
}
for index, tc := range testcases {
ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger)
tempStore, _ := state.CreateStore(tc.funcName, def.AtMostOnce)
fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), index)
f, ok := builtins[tc.funcName]
require.True(t, ok)
got1, ok := f.exec(fctx, tc.args1)
require.True(t, ok)
require.Equal(t, tc.output1, got1)
got2, ok := f.exec(fctx, tc.args2)
require.True(t, ok)
require.Equal(t, tc.output2, got2)
}
}

0 comments on commit e4d7792

Please sign in to comment.