From 760d28af9b0cf40bcfdc7bea2f20e8fb6035b94c Mon Sep 17 00:00:00 2001 From: lovewin99 Date: Tue, 9 Jul 2019 13:50:10 +0800 Subject: [PATCH] fix issue:#11102 Unexcepted result in `SELECT ... CASE WHEN ... ELSE NULL...` --- expression/constant_fold.go | 58 ++++++++++++++++++++++++++++++------- session/session_test.go | 4 +-- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 2437b269e31b6..4ce3765ca9427 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -15,6 +15,7 @@ package expression import ( "github.com/pingcap/parser/ast" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" @@ -23,18 +24,11 @@ import ( // specialFoldHandler stores functions for special UDF to constant fold var specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){} -// specialNullRejectCheck stores function names for special UDF to skip constant fold. -var specialNullRejectCheck = map[string]struct{}{} - func init() { specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){ ast.If: ifFoldHandler, ast.Ifnull: ifNullFoldHandler, - } - - specialNullRejectCheck = map[string]struct{}{ - ast.NullEQ: {}, - ast.Case: {}, + ast.Case: caseWhenHandler, } } @@ -86,6 +80,51 @@ func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) { return expr, isDeferredConst } +func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { + args, l := expr.GetArgs(), len(expr.GetArgs()) + firstCondition := true + isDeferredConst := false + for i := 0; i < l-1; i += 2 { + foldedArg, isDeferred := foldConstant(args[i]) + expr.GetArgs()[i] = foldedArg + isDeferredConst = isDeferredConst || isDeferred + if _, isConst := foldedArg.(*Constant); isConst { + condition, isNull, err := args[i].EvalInt(expr.GetCtx(), chunk.Row{}) + if err != nil { + return expr, false + } + if firstCondition && condition != 0 && !isNull { + return retProcess(args[i+1], expr.GetType()) + } + } else { + firstCondition = false + } + foldedArg1, isDeferred1 := foldConstant(args[i+1]) + expr.GetArgs()[i+1] = foldedArg1 + isDeferredConst = isDeferredConst || isDeferred1 + } + + if l%2 == 1 && firstCondition { + return retProcess(args[l-1], expr.GetType()) + } else if l%2 == 1 { + foldedArg, isDeferred := foldConstant(args[l-1]) + expr.GetArgs()[l-1] = foldedArg + isDeferredConst = isDeferredConst || isDeferred + } + + return expr, isDeferredConst +} + +func retProcess(expr Expression, retType *types.FieldType) (Expression, bool) { + foldedExpr, b := foldConstant(expr) + if fc, isConst := foldedExpr.(*Constant); isConst { + fc.RetType = retType + return fc, b + } + + return foldedExpr, b +} + func foldConstant(expr Expression) (Expression, bool) { switch x := expr.(type) { case *ScalarFunction: @@ -112,8 +151,7 @@ func foldConstant(expr Expression) (Expression, bool) { isDeferredConst = isDeferredConst || isDeferred } if !allConstArg { - _, ok := specialNullRejectCheck[x.FuncName.L] - if !hasNullArg || !sc.InNullRejectCheck || ok { + if !hasNullArg || !sc.InNullRejectCheck || x.FuncName.L == ast.NullEQ { return expr, isDeferredConst } constArgs := make([]Expression, len(args)) diff --git a/session/session_test.go b/session/session_test.go index f2dc3a9691746..4c5e8b2a21057 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -2733,7 +2733,5 @@ func (s *testSessionSuite) TestFuncCaseWithLeftJoin(c *C) { tk.MustExec("create table kankan2(id int, h1 text)") tk.MustExec("insert into kankan2 values(2, 'z')") - tk.MustQuery("select * from (select t1.id, t2.h1, case when t1.name='b' then 'case2' when t1.name='a' then " + - "'case1' else null end as flag from kankan1 t1 left join kankan2 t2 on t1.id = t2.id) t3 where t3.flag='case1' " + - "order by t3.id").Check(testkit.Rows("1 case1", "2 z case1")) + tk.MustQuery("select t1.id from kankan1 t1 left join kankan2 t2 on t1.id = t2.id where (case when t1.name='b' then 'case2' when t1.name='a' then 'case1' else NULL end) = 'case1' order by t1.id").Check(testkit.Rows("1", "2")) }