diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 4c0c906744b83..a6e000bff422f 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -23,11 +23,19 @@ import ( // specialFoldHandler stores functions for special UDF to constant fold var specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){} +// specialNullRejectCheck stores function names for special UDF to skip constant fold. +var specialNullRejectCheck = map[string]struct{}{} + func init() { specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){ ast.If: ifFoldHandler, ast.Ifnull: ifNullFoldHandler, } + + specialNullRejectCheck = map[string]struct{}{ + ast.NullEQ: struct{}{}, + ast.Case: struct{}{}, + } } // FoldConstant does constant folding optimization on an expression excluding deferred ones. @@ -104,7 +112,8 @@ func foldConstant(expr Expression) (Expression, bool) { isDeferredConst = isDeferredConst || isDeferred } if !allConstArg { - if !hasNullArg || !sc.InNullRejectCheck || x.FuncName.L == ast.NullEQ { + _, ok := specialNullRejectCheck[x.FuncName.L] + if !hasNullArg || !sc.InNullRejectCheck || ok { return expr, isDeferredConst } constArgs := make([]Expression, len(args)) diff --git a/session/session_test.go b/session/session_test.go index 514e374ee99c4..98c8efb4a999b 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -2722,3 +2722,18 @@ func (s *testSessionSuite) TestLoadClientInteractive(c *C) { tk.Se.GetSessionVars().ClientCapability = tk.Se.GetSessionVars().ClientCapability | mysql.ClientInteractive tk.MustQuery("select @@wait_timeout").Check(testkit.Rows("28800")) } + +func (s *testSessionSuite) TestFuncCaseWithLeftJoin(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + + tk.MustExec("create table kankan1(id int, name text)") + tk.MustExec("insert into kankan1 values(1, 'a')") + tk.MustExec("insert into kankan1 values(2, 'a')") + + tk.MustExec("create table kankan2(id int, h1 text)") + tk.MustExec("insert into kankan2 values(2, 'z')") + + tk.MustQuery( "select * from (select t1.id, t2.h1, case when t1.name='b' then 'case2' when t1.name='a' then " + + "'case1' else null end as flag from kankan1 t1 left join kankan2 t2 on t1.id = t2.id) t3 where t3.flag='case1' " + + "order by t3.id").Check(testkit.Rows("1 case1", "2 z case1")) +}