Skip to content

Commit

Permalink
executor: support window function nth_value (#9596)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx authored and zz-jason committed Mar 11, 2019
1 parent 0dada1e commit 6e8cd3c
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 3 deletions.
12 changes: 12 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
return buildLastValue(windowFuncDesc, ordinal)
case ast.WindowFuncCumeDist:
return buildCumeDist(ordinal, orderByCols)
case ast.WindowFuncNthValue:
return buildNthValue(windowFuncDesc, ordinal)
default:
return Build(ctx, windowFuncDesc, ordinal)
}
Expand Down Expand Up @@ -374,3 +376,13 @@ func buildCumeDist(ordinal int, orderByCols []*expression.Column) AggFunc {
r := &cumeDist{baseAggFunc: base, rowComparer: buildRowComparer(orderByCols)}
return r
}

func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
// Already checked when building the function description.
nth, _, _ := expression.GetUint64FromConstant(aggFuncDesc.Args[1])
return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth}
}
47 changes: 47 additions & 0 deletions executor/aggfuncs/func_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,50 @@ func (v *lastValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialR
}
return nil
}

type nthValue struct {
baseAggFunc

tp *types.FieldType
nth uint64
}

type partialResult4NthValue struct {
seenRows uint64
evaluator valueEvaluator
}

func (v *nthValue) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4NthValue{evaluator: buildValueEvaluator(v.tp)})
}

func (v *nthValue) ResetPartialResult(pr PartialResult) {
p := (*partialResult4NthValue)(pr)
p.seenRows = 0
}

func (v *nthValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
if v.nth == 0 {
return nil
}
p := (*partialResult4NthValue)(pr)
numRows := uint64(len(rowsInGroup))
if v.nth > p.seenRows && v.nth-p.seenRows <= numRows {
err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[v.nth-p.seenRows-1])
if err != nil {
return err
}
}
p.seenRows += numRows
return nil
}

func (v *nthValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4NthValue)(pr)
if v.nth == 0 || p.seenRows < v.nth {
chk.AppendNull(v.ordinal)
} else {
p.evaluator.appendResult(chk, v.ordinal)
}
return nil
}
9 changes: 9 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,13 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result.Check(testkit.Rows("1 1 0.5", "1 2 0.5", "2 1 1", "2 2 1"))
result = tk.MustQuery("select a, b, cume_dist() over(order by a, b) from t")
result.Check(testkit.Rows("1 1 0.25", "1 2 0.5", "2 1 0.75", "2 2 1"))

result = tk.MustQuery("select a, nth_value(a, null) over() from t")
result.Check(testkit.Rows("1 <nil>", "1 <nil>", "2 <nil>", "2 <nil>"))
result = tk.MustQuery("select a, nth_value(a, 1) over() from t")
result.Check(testkit.Rows("1 1", "1 1", "2 1", "2 1"))
result = tk.MustQuery("select a, nth_value(a, 4) over() from t")
result.Check(testkit.Rows("1 2", "1 2", "2 2", "2 2"))
result = tk.MustQuery("select a, nth_value(a, 5) over() from t")
result.Check(testkit.Rows("1 <nil>", "1 <nil>", "2 <nil>", "2 <nil>"))
}
2 changes: 1 addition & 1 deletion expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
case ast.AggFuncGroupConcat:
a.typeInfer4GroupConcat(ctx)
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow,
ast.WindowFuncFirstValue, ast.WindowFuncLastValue:
ast.WindowFuncFirstValue, ast.WindowFuncLastValue, ast.WindowFuncNthValue:
a.typeInfer4MaxMin(ctx)
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
a.typeInfer4BitFuncs(ctx)
Expand Down
7 changes: 7 additions & 0 deletions expression/aggregation/window_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ type WindowFuncDesc struct {

// NewWindowFuncDesc creates a window function signature descriptor.
func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc {
if strings.ToLower(name) == ast.WindowFuncNthValue {
val, isNull, ok := expression.GetUint64FromConstant(args[1])
// nth_value does not allow `0`, but allows `null`.
if !ok || (val == 0 && !isNull) {
return nil
}
}
return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)}
}

Expand Down
33 changes: 33 additions & 0 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"unicode"

"github.com/pingcap/errors"
"github.com/pingcap/log"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/opcode"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util/chunk"
"go.uber.org/zap"
"golang.org/x/tools/container/intsets"
)

Expand Down Expand Up @@ -670,3 +672,34 @@ func RemoveDupExprs(ctx sessionctx.Context, exprs []Expression) []Expression {
}
return res
}

// GetUint64FromConstant gets a uint64 from constant expression.
func GetUint64FromConstant(expr Expression) (uint64, bool, bool) {
con, ok := expr.(*Constant)
if !ok {
log.Warn("not a constant expression", zap.Any("value", expr))
return 0, false, false
}
dt := con.Value
if con.DeferredExpr != nil {
var err error
dt, err = con.DeferredExpr.Eval(chunk.Row{})
if err != nil {
log.Warn("eval deferred expr failed", zap.Error(err))
return 0, false, false
}
}
switch dt.Kind() {
case types.KindNull:
return 0, true, true
case types.KindInt64:
val := dt.GetInt64()
if val < 0 {
return 0, false, false
}
return uint64(val), false, true
case types.KindUint64:
return dt.GetUint64(), false, true
}
return 0, false, false
}
8 changes: 6 additions & 2 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2783,8 +2783,9 @@ func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFu
return nil, nil, nil, nil, err
}
p = np
if col, ok := newArg.(*expression.Column); ok {
newArgList = append(newArgList, col)
switch newArg.(type) {
case *expression.Column, *expression.Constant:
newArgList = append(newArgList, newArg)
continue
}
proj.Exprs = append(proj.Exprs, newArg)
Expand Down Expand Up @@ -2966,6 +2967,9 @@ func (b *PlanBuilder) buildWindowFunction(p LogicalPlan, expr *ast.WindowFuncExp
return nil, err
}
desc := aggregation.NewWindowFuncDesc(b.ctx, expr.F, args)
if desc == nil {
return nil, ErrWrongArguments.GenWithStackByArgs(expr.F)
}
// TODO: Check if the function is aggregation function after we support more functions.
desc.WrapCastForAggArgs(b.ctx)
window := LogicalWindow{
Expand Down
8 changes: 8 additions & 0 deletions planner/core/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2199,6 +2199,14 @@ func (s *testPlanSuite) TestWindowFunction(c *C) {
sql: "select row_number() over(rows between 1 preceding and 1 following) from t",
result: "TableReader(Table(t))->Window(row_number() over())->Projection",
},
{
sql: "select nth_value(a, 1.0) over() from t",
result: "[planner:1210]Incorrect arguments to nth_value",
},
{
sql: "select nth_value(a, 0) over() from t",
result: "[planner:1210]Incorrect arguments to nth_value",
},
}

s.Parser.EnableWindowFunc(true)
Expand Down

0 comments on commit 6e8cd3c

Please sign in to comment.