diff --git a/expression/constant_fold.go b/expression/constant_fold.go index a6f97092d0285..c9dbcbbd95ef6 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -81,12 +81,11 @@ func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) { func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { args, l := expr.GetArgs(), len(expr.GetArgs()) - isDeferredConst, hasNonConstCondition := false, true + isDeferred, isDeferredConst, hasNonConstCondition := false, false, true for i := 0; i < l-1; i += 2 { - foldedArg, isDeferred := foldConstant(args[i]) - expr.GetArgs()[i] = foldedArg + expr.GetArgs()[i], isDeferred = foldConstant(args[i]) isDeferredConst = isDeferredConst || isDeferred - if _, isConst := foldedArg.(*Constant); isConst && hasNonConstCondition { + if _, isConst := expr.GetArgs()[i].(*Constant); isConst && hasNonConstCondition { // If the condition is const and true, and the previous conditions // has no expr, then the folded execution body is returned, otherwise // the arguments of the casewhen are folded and replaced. @@ -96,9 +95,8 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { } if val != 0 && !isNull { foldedExpr, isDeferred := foldConstant(args[i+1]) - foldedExpr.GetType().Decimal = expr.GetType().Decimal isDeferredConst = isDeferredConst || isDeferred - return foldedExpr, isDeferredConst + return BuildCastFunction(expr.GetCtx(), foldedExpr, expr.GetType()), isDeferredConst } } else { hasNonConstCondition = false @@ -112,12 +110,10 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { // is const and false, then the folded else execution body is returned. otherwise // the execution body of the else are folded and replaced. foldedExpr, isDeferred := foldConstant(args[l-1]) - foldedExpr.GetType().Decimal = expr.GetType().Decimal isDeferredConst = isDeferredConst || isDeferred - return foldedExpr, isDeferredConst + return BuildCastFunction(expr.GetCtx(), foldedExpr, expr.GetType()), isDeferredConst } else if l%2 == 1 { - foldedArg, isDeferred := foldConstant(args[l-1]) - expr.GetArgs()[l-1] = foldedArg + expr.GetArgs()[l-1], isDeferred = foldConstant(args[l-1]) isDeferredConst = isDeferredConst || isDeferred }