Skip to content

Commit

Permalink
expression: fix double negation error when PushDownNot (pingcap#16094) (
Browse files Browse the repository at this point in the history
  • Loading branch information
sre-bot authored Apr 16, 2020
1 parent fdf0fb4 commit bbc3206
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 44 deletions.
4 changes: 2 additions & 2 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}
argTps := make([]types.EvalType, 0, l)
for i := 0; i < l-1; i += 2 {
if args[i], err = wrapWithIsTrue(ctx, true, args[i]); err != nil {
if args[i], err = wrapWithIsTrue(ctx, true, args[i], false); err != nil {
return nil, err
}
argTps = append(argTps, types.ETInt, tp)
Expand Down Expand Up @@ -474,7 +474,7 @@ func (c *ifFunctionClass) getFunction(ctx sessionctx.Context, args []Expression)
}
retTp := InferType4ControlFuncs(args[1].GetType(), args[2].GetType())
evalTps := retTp.EvalType()
args[0], err = wrapWithIsTrue(ctx, true, args[0])
args[0], err = wrapWithIsTrue(ctx, true, args[0], false)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ func (c *logicAndFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}
args[0], err = wrapWithIsTrue(ctx, true, args[0])
args[0], err = wrapWithIsTrue(ctx, true, args[0], false)
if err != nil {
return nil, errors.Trace(err)
}
args[1], err = wrapWithIsTrue(ctx, true, args[1])
args[1], err = wrapWithIsTrue(ctx, true, args[1], false)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -117,11 +117,11 @@ func (c *logicOrFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
if err != nil {
return nil, err
}
args[0], err = wrapWithIsTrue(ctx, true, args[0])
args[0], err = wrapWithIsTrue(ctx, true, args[0], false)
if err != nil {
return nil, errors.Trace(err)
}
args[1], err = wrapWithIsTrue(ctx, true, args[1])
args[1], err = wrapWithIsTrue(ctx, true, args[1], false)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
12 changes: 10 additions & 2 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,17 @@ func IsBinaryLiteral(expr Expression) bool {
// The `keepNull` controls what the istrue function will return when `arg` is null:
// 1. keepNull is true and arg is null, the istrue function returns null.
// 2. keepNull is false and arg is null, the istrue function returns 0.
func wrapWithIsTrue(ctx sessionctx.Context, keepNull bool, arg Expression) (Expression, error) {
// The `wrapForInt` indicates whether we need to wrapIsTrue for non-logical Expression with int type.
func wrapWithIsTrue(ctx sessionctx.Context, keepNull bool, arg Expression, wrapForInt bool) (Expression, error) {
if arg.GetType().EvalType() == types.ETInt {
return arg, nil
if !wrapForInt {
return arg, nil
}
if child, ok := arg.(*ScalarFunction); ok {
if _, isLogicalOp := logicalOps[child.FuncName.L]; isLogicalOp {
return arg, nil
}
}
}
fc := &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, keepNull}
f, err := fc.getFunction(ctx, []Expression{arg})
Expand Down
12 changes: 11 additions & 1 deletion expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4055,7 +4055,7 @@ func (s *testIntegrationSuite) TestFilterExtractFromDNF(c *C) {
selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection)
conds := make([]expression.Expression, 0, len(selection.Conditions))
for _, cond := range selection.Conditions {
conds = append(conds, expression.PushDownNot(sctx, cond, false))
conds = append(conds, expression.PushDownNot(sctx, cond))
}
afterFunc := expression.ExtractFiltersFromDNFs(sctx, conds)
sort.Slice(afterFunc, func(i, j int) bool {
Expand Down Expand Up @@ -4919,6 +4919,16 @@ func (s *testIntegrationSuite) TestValuesForBinaryLiteral(c *C) {
tk.MustExec("drop table testValuesBinary;")
}

func (s *testIntegrationSuite) TestIssue15725(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test;")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int)")
tk.MustExec("insert into t values(2)")
tk.MustQuery("select * from t where (not not a) = a").Check(testkit.Rows())
tk.MustQuery("select * from t where (not not not not a) = a").Check(testkit.Rows())
}

func (s *testIntegrationSuite) TestIssue15790(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test;")
Expand Down
101 changes: 74 additions & 27 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,33 @@ func timeZone2Duration(tz string) time.Duration {
return time.Duration(sign) * (time.Duration(h)*time.Hour + time.Duration(m)*time.Minute)
}

var logicalOps = map[string]struct{}{
ast.LT: {},
ast.GE: {},
ast.GT: {},
ast.LE: {},
ast.EQ: {},
ast.NE: {},
ast.UnaryNot: {},
ast.LogicAnd: {},
ast.LogicOr: {},
ast.LogicXor: {},
ast.In: {},
ast.IsNull: {},
ast.IsTruth: {},
ast.IsFalsity: {},
ast.Like: {},
}

var oppositeOp = map[string]string{
ast.LT: ast.GE,
ast.GE: ast.LT,
ast.GT: ast.LE,
ast.LE: ast.GT,
ast.EQ: ast.NE,
ast.NE: ast.EQ,
ast.LT: ast.GE,
ast.GE: ast.LT,
ast.GT: ast.LE,
ast.LE: ast.GT,
ast.EQ: ast.NE,
ast.NE: ast.EQ,
ast.LogicOr: ast.LogicAnd,
ast.LogicAnd: ast.LogicOr,
}

// a op b is equal to b symmetricOp a
Expand All @@ -289,46 +309,73 @@ var symmetricOp = map[opcode.Op]opcode.Op{
opcode.NullEQ: opcode.NullEQ,
}

func doPushDownNot(ctx sessionctx.Context, exprs []Expression, not bool) []Expression {
func pushNotAcrossArgs(ctx sessionctx.Context, exprs []Expression, not bool) ([]Expression, bool) {
newExprs := make([]Expression, 0, len(exprs))
flag := false
for _, expr := range exprs {
newExprs = append(newExprs, PushDownNot(ctx, expr, not))
newExpr, changed := pushNotAcrossExpr(ctx, expr, not)
flag = changed || flag
newExprs = append(newExprs, newExpr)
}
return newExprs
return newExprs, flag
}

// PushDownNot pushes the `not` function down to the expression's arguments.
func PushDownNot(ctx sessionctx.Context, expr Expression, not bool) Expression {
// pushNotAcrossExpr try to eliminate the NOT expr in expression tree.
// Input `not` indicates whether there's a `NOT` be pushed down.
// Output `changed` indicates whether the output expression differs from the
// input `expr` because of the pushed-down-not.
func pushNotAcrossExpr(ctx sessionctx.Context, expr Expression, not bool) (_ Expression, changed bool) {
if f, ok := expr.(*ScalarFunction); ok {
switch f.FuncName.L {
case ast.UnaryNot:
return PushDownNot(f.GetCtx(), f.GetArgs()[0], !not)
child, err := wrapWithIsTrue(ctx, true, f.GetArgs()[0], true)
if err != nil {
return expr, false
}
var childExpr Expression
childExpr, changed = pushNotAcrossExpr(f.GetCtx(), child, !not)
if !changed && !not {
return expr, false
}
return childExpr, true
case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE:
if not {
return NewFunctionInternal(f.GetCtx(), oppositeOp[f.FuncName.L], f.GetType(), f.GetArgs()...)
return NewFunctionInternal(f.GetCtx(), oppositeOp[f.FuncName.L], f.GetType(), f.GetArgs()...), true
}
newArgs := doPushDownNot(f.GetCtx(), f.GetArgs(), false)
return NewFunctionInternal(f.GetCtx(), f.FuncName.L, f.GetType(), newArgs...)
case ast.LogicAnd:
if not {
newArgs := doPushDownNot(f.GetCtx(), f.GetArgs(), true)
return NewFunctionInternal(f.GetCtx(), ast.LogicOr, f.GetType(), newArgs...)
newArgs, changed := pushNotAcrossArgs(f.GetCtx(), f.GetArgs(), false)
if !changed {
return f, false
}
newArgs := doPushDownNot(f.GetCtx(), f.GetArgs(), false)
return NewFunctionInternal(f.GetCtx(), f.FuncName.L, f.GetType(), newArgs...)
case ast.LogicOr:
return NewFunctionInternal(f.GetCtx(), f.FuncName.L, f.GetType(), newArgs...), true
case ast.LogicAnd, ast.LogicOr:
var (
newArgs []Expression
changed bool
)
funcName := f.FuncName.L
if not {
newArgs := doPushDownNot(f.GetCtx(), f.GetArgs(), true)
return NewFunctionInternal(f.GetCtx(), ast.LogicAnd, f.GetType(), newArgs...)
newArgs, _ = pushNotAcrossArgs(f.GetCtx(), f.GetArgs(), true)
funcName = oppositeOp[f.FuncName.L]
changed = true
} else {
newArgs, changed = pushNotAcrossArgs(f.GetCtx(), f.GetArgs(), false)
}
if !changed {
return f, false
}
newArgs := doPushDownNot(f.GetCtx(), f.GetArgs(), false)
return NewFunctionInternal(f.GetCtx(), f.FuncName.L, f.GetType(), newArgs...)
return NewFunctionInternal(f.GetCtx(), funcName, f.GetType(), newArgs...), true
}
}
if not {
expr = NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr)
}
return expr
return expr, not
}

// PushDownNot pushes the `not` function down to the expression's arguments.
func PushDownNot(ctx sessionctx.Context, expr Expression) Expression {
newExpr, _ := pushNotAcrossExpr(ctx, expr, false)
return newExpr
}

// Contains tests if `exprs` contains `e`.
Expand Down
31 changes: 30 additions & 1 deletion expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,38 @@ func (s *testUtilSuite) TestPushDownNot(c *check.C) {
andFunc2 := newFunction(ast.LogicAnd, neFunc, neFunc)
orFunc2 := newFunction(ast.LogicOr, andFunc2, neFunc)
notFuncCopy := notFunc.Clone()
ret := PushDownNot(ctx, notFunc, false)
ret := PushDownNot(ctx, notFunc)
c.Assert(ret.Equal(ctx, orFunc2), check.IsTrue)
c.Assert(notFunc.Equal(ctx, notFuncCopy), check.IsTrue)

// issue 15725
// (not not a) should be optimized to (a is true)
notFunc = newFunction(ast.UnaryNot, col)
notFunc = newFunction(ast.UnaryNot, notFunc)
ret = PushDownNot(ctx, notFunc)
c.Assert(ret.Equal(ctx, newFunction(ast.IsTruth, col)), check.IsTrue)

// (not not (a+1)) should be optimized to (a+1 is true)
plusFunc := newFunction(ast.Plus, col, One)
notFunc = newFunction(ast.UnaryNot, plusFunc)
notFunc = newFunction(ast.UnaryNot, notFunc)
ret = PushDownNot(ctx, notFunc)
c.Assert(ret.Equal(ctx, newFunction(ast.IsTruth, plusFunc)), check.IsTrue)

// (not not not a) should be optimized to (not (a is true))
notFunc = newFunction(ast.UnaryNot, col)
notFunc = newFunction(ast.UnaryNot, notFunc)
notFunc = newFunction(ast.UnaryNot, notFunc)
ret = PushDownNot(ctx, notFunc)
c.Assert(ret.Equal(ctx, newFunction(ast.UnaryNot, newFunction(ast.IsTruth, col))), check.IsTrue)

// (not not not not a) should be optimized to (a is true)
notFunc = newFunction(ast.UnaryNot, col)
notFunc = newFunction(ast.UnaryNot, notFunc)
notFunc = newFunction(ast.UnaryNot, notFunc)
notFunc = newFunction(ast.UnaryNot, notFunc)
ret = PushDownNot(ctx, notFunc)
c.Assert(ret.Equal(ctx, newFunction(ast.IsTruth, col)), check.IsTrue)
}

func (s *testUtilSuite) TestFilter(c *check.C) {
Expand Down
2 changes: 1 addition & 1 deletion planner/core/rule_predicate_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func simplifyOuterJoin(p *LogicalJoin, predicates []expression.Expression) {
// If it is a conjunction containing a null-rejected condition as a conjunct.
// If it is a disjunction of null-rejected conditions.
func isNullRejected(ctx sessionctx.Context, schema *expression.Schema, expr expression.Expression) bool {
expr = expression.PushDownNot(nil, expr, false)
expr = expression.PushDownNot(ctx, expr)
sc := ctx.GetSessionVars().StmtCtx
sc.InNullRejectCheck = true
result := expression.EvaluateExprWithNull(ctx, schema, expr)
Expand Down
2 changes: 1 addition & 1 deletion planner/core/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (ds *DataSource) deriveStatsByFilter(conds expression.CNFExprs) {
func (ds *DataSource) DeriveStats(childStats []*property.StatsInfo) (*property.StatsInfo, error) {
// PushDownNot here can convert query 'not (a != 1)' to 'a = 1'.
for i, expr := range ds.pushedDownConds {
ds.pushedDownConds[i] = expression.PushDownNot(nil, expr, false)
ds.pushedDownConds[i] = expression.PushDownNot(ds.ctx, expr)
}
ds.deriveStatsByFilter(ds.pushedDownConds)
for _, path := range ds.possibleAccessPaths {
Expand Down
2 changes: 1 addition & 1 deletion planner/core/testdata/plan_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@
},
{
"SQL": "select a from t where not a",
"Best": "IndexReader(Index(t.f)[[NULL,+inf]]->Sel([not(test.t.a)]))"
"Best": "TableReader(Table(t))"
},
{
"SQL": "select a from t where c in (1)",
Expand Down
8 changes: 4 additions & 4 deletions util/ranger/ranger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func (s *testRangerSuite) TestTableRange(c *C) {
selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection)
conds := make([]expression.Expression, 0, len(selection.Conditions))
for _, cond := range selection.Conditions {
conds = append(conds, expression.PushDownNot(sctx, cond, false))
conds = append(conds, expression.PushDownNot(sctx, cond))
}
tbl := selection.Children()[0].(*plannercore.DataSource).TableInfo()
col := expression.ColInfo2Col(selection.Schema().Columns, tbl.Columns[0])
Expand Down Expand Up @@ -603,7 +603,7 @@ func (s *testRangerSuite) TestIndexRange(c *C) {
c.Assert(selection, NotNil, Commentf("expr:%v", tt.exprStr))
conds := make([]expression.Expression, 0, len(selection.Conditions))
for _, cond := range selection.Conditions {
conds = append(conds, expression.PushDownNot(sctx, cond, false))
conds = append(conds, expression.PushDownNot(sctx, cond))
}
cols, lengths := expression.IndexInfo2Cols(selection.Schema().Columns, tbl.Indices[tt.indexPos])
c.Assert(cols, NotNil)
Expand Down Expand Up @@ -724,7 +724,7 @@ func (s *testRangerSuite) TestIndexRangeForUnsignedInt(c *C) {
c.Assert(selection, NotNil, Commentf("expr:%v", tt.exprStr))
conds := make([]expression.Expression, 0, len(selection.Conditions))
for _, cond := range selection.Conditions {
conds = append(conds, expression.PushDownNot(sctx, cond, false))
conds = append(conds, expression.PushDownNot(sctx, cond))
}
cols, lengths := expression.IndexInfo2Cols(selection.Schema().Columns, tbl.Indices[tt.indexPos])
c.Assert(cols, NotNil)
Expand Down Expand Up @@ -1055,7 +1055,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) {
c.Assert(ok, IsTrue, Commentf("expr:%v", tt.exprStr))
conds := make([]expression.Expression, 0, len(sel.Conditions))
for _, cond := range sel.Conditions {
conds = append(conds, expression.PushDownNot(sctx, cond, false))
conds = append(conds, expression.PushDownNot(sctx, cond))
}
col := expression.ColInfo2Col(sel.Schema().Columns, ds.TableInfo().Columns[tt.colPos])
c.Assert(col, NotNil)
Expand Down

0 comments on commit bbc3206

Please sign in to comment.