diff --git a/executor/executor.go b/executor/executor.go index 7e6009ac72166..5ba8865073dab 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -909,7 +909,7 @@ func (e *SelectionExec) Next(ctx context.Context, chk *chunk.Chunk) error { func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) error { for { for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { - selected, err := expression.EvalBool(e.ctx, e.filters, e.inputRow) + selected, _, err := expression.EvalBool(e.ctx, e.filters, e.inputRow) if err != nil { return errors.Trace(err) } diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go index d0e83c4f25f88..4927bca36e751 100644 --- a/executor/index_lookup_join.go +++ b/executor/index_lookup_join.go @@ -234,7 +234,7 @@ func (e *IndexLookUpJoin) Next(ctx context.Context, chk *chunk.Chunk) error { outerRow := task.outerResult.GetRow(task.cursor) if e.innerIter.Current() != e.innerIter.End() { - matched, err := e.joiner.tryToMatch(outerRow, e.innerIter, chk) + matched, _, err := e.joiner.tryToMatch(outerRow, e.innerIter, chk) if err != nil { return errors.Trace(err) } @@ -242,7 +242,7 @@ func (e *IndexLookUpJoin) Next(ctx context.Context, chk *chunk.Chunk) error { } if e.innerIter.Current() == e.innerIter.End() { if !task.hasMatch { - e.joiner.onMissMatch(outerRow, chk) + e.joiner.onMissMatch(false, outerRow, chk) } task.cursor++ task.hasMatch = false diff --git a/executor/join.go b/executor/join.go index 3f64e7a318cd0..eed268d096f52 100644 --- a/executor/join.go +++ b/executor/join.go @@ -433,13 +433,13 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R return false, joinResult } if hasNull { - e.joiners[workerID].onMissMatch(outerRow, joinResult.chk) + e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) return true, joinResult } e.hashTableValBufs[workerID] = e.hashTable.Get(joinKey, e.hashTableValBufs[workerID][:0]) innerPtrs := e.hashTableValBufs[workerID] if len(innerPtrs) == 0 { - e.joiners[workerID].onMissMatch(outerRow, joinResult.chk) + e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) return true, joinResult } innerRows := make([]chunk.Row, 0, len(innerPtrs)) @@ -451,7 +451,7 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R iter := chunk.NewIterator4Slice(innerRows) hasMatch := false for iter.Begin(); iter.Current() != iter.End(); { - matched, err := e.joiners[workerID].tryToMatch(outerRow, iter, joinResult.chk) + matched, _, err := e.joiners[workerID].tryToMatch(outerRow, iter, joinResult.chk) if err != nil { joinResult.err = errors.Trace(err) return false, joinResult @@ -468,7 +468,7 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R } } if !hasMatch { - e.joiners[workerID].onMissMatch(outerRow, joinResult.chk) + e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) } return true, joinResult } @@ -496,7 +496,7 @@ func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, joinResu } for i := range selected { if !selected[i] { // process unmatched outer rows - e.joiners[workerID].onMissMatch(outerChk.GetRow(i), joinResult.chk) + e.joiners[workerID].onMissMatch(false, outerChk.GetRow(i), joinResult.chk) } else { // process matched outer rows ok, joinResult = e.joinMatchedOuterRow2Chunk(workerID, outerChk.GetRow(i), joinResult) if !ok { @@ -634,6 +634,7 @@ type NestedLoopApplyExec struct { innerIter chunk.Iterator outerRow *chunk.Row hasMatch bool + isNull bool memTracker *memory.Tracker // track memory usage. } @@ -691,7 +692,7 @@ func (e *NestedLoopApplyExec) fetchSelectedOuterRow(ctx context.Context, chk *ch if selected { return &outerRow, nil } else if e.outer { - e.joiner.onMissMatch(outerRow, chk) + e.joiner.onMissMatch(false, outerRow, chk) if chk.NumRows() == e.maxChunkSize { return nil, nil } @@ -739,13 +740,14 @@ func (e *NestedLoopApplyExec) Next(ctx context.Context, chk *chunk.Chunk) (err e for { if e.innerIter == nil || e.innerIter.Current() == e.innerIter.End() { if e.outerRow != nil && !e.hasMatch { - e.joiner.onMissMatch(*e.outerRow, chk) + e.joiner.onMissMatch(e.isNull, *e.outerRow, chk) } e.outerRow, err = e.fetchSelectedOuterRow(ctx, chk) if e.outerRow == nil || err != nil { return errors.Trace(err) } e.hasMatch = false + e.isNull = false for _, col := range e.outerSchema { *col.Data = e.outerRow.GetDatum(col.Index, col.RetType) @@ -758,8 +760,9 @@ func (e *NestedLoopApplyExec) Next(ctx context.Context, chk *chunk.Chunk) (err e e.innerIter.Begin() } - matched, err := e.joiner.tryToMatch(*e.outerRow, e.innerIter, chk) + matched, isNull, err := e.joiner.tryToMatch(*e.outerRow, e.innerIter, chk) e.hasMatch = e.hasMatch || matched + e.isNull = e.isNull || isNull if err != nil || chk.NumRows() == e.maxChunkSize { return errors.Trace(err) diff --git a/executor/join_test.go b/executor/join_test.go index e51fd86df9ab4..14d3b62f4c5cf 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -988,3 +988,92 @@ func (s *testSuite2) TestHashJoin(c *C) { innerExecInfo = row[3][4].(string) c.Assert(innerExecInfo[len(innerExecInfo)-1:], LessEqual, "5") } + +func (s *testSuite) TestNotInAntiJoin(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, c int)") + tk.MustExec("insert into t values(null, 1, 0), (1, 2, 0)") + tk.MustQuery("select a, b from t t1 where a not in (select b from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a not in (select b from t t2 where t1.b = t2.a)").Check(testkit.Rows( + "1 2", + )) + tk.MustQuery("select a, b from t t1 where a not in (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a not in (select a from t t2 where t1.b = t2.b)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a != all (select b from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a != all (select b from t t2 where t1.b = t2.a)").Check(testkit.Rows( + "1 2", + )) + tk.MustQuery("select a, b from t t1 where a != all (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a != all (select a from t t2 where t1.b = t2.b)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where not exists (select * from t t2 where t1.a = t2.b)").Check(testkit.Rows( + " 1", + )) + tk.MustQuery("select a, b from t t1 where not exists (select * from t t2 where t1.a = t2.a)").Check(testkit.Rows( + " 1", + )) + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(1, null, 0), (2, 1, 0)") + tk.MustQuery("select a, b from t t1 where a not in (select b from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a not in (select b from t t2 where t1.b = t2.a)").Check(testkit.Rows( + "1 ", + )) + tk.MustQuery("select a, b from t t1 where a not in (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a not in (select a from t t2 where t1.b = t2.b)").Check(testkit.Rows( + "1 ", + )) + tk.MustQuery("select a, b from t t1 where a != all (select b from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a != all (select b from t t2 where t1.b = t2.a)").Check(testkit.Rows( + "1 ", + )) + tk.MustQuery("select a, b from t t1 where a != all (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a != all (select a from t t2 where t1.b = t2.b)").Check(testkit.Rows( + "1 ", + )) + tk.MustQuery("select a, b from t t1 where not exists (select * from t t2 where t1.a = t2.b)").Check(testkit.Rows( + "2 1", + )) + tk.MustQuery("select a, b from t t1 where not exists (select * from t t2 where t1.a = t2.a)").Check(testkit.Rows()) + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(1, null, 0), (2, 1, 0), (null, 2, 0)") + tk.MustQuery("select a, b from t t1 where a not in (select b from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a not in (select b from t t2 where t1.b = t2.a)").Check(testkit.Rows( + "1 ", + )) + tk.MustQuery("select a, b from t t1 where a not in (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a not in (select a from t t2 where t1.b = t2.b)").Check(testkit.Rows( + "1 ", + )) + tk.MustQuery("select a, b from t t1 where a != all (select b from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a != all (select b from t t2 where t1.b = t2.a)").Check(testkit.Rows( + "1 ", + )) + tk.MustQuery("select a, b from t t1 where a != all (select a from t t2)").Check(testkit.Rows()) + tk.MustQuery("select a, b from t t1 where a != all (select a from t t2 where t1.b = t2.b)").Check(testkit.Rows( + "1 ", + )) + tk.MustQuery("select a, b from t t1 where not exists (select * from t t2 where t1.a = t2.b)").Check(testkit.Rows( + " 2", + )) + tk.MustQuery("select a, b from t t1 where not exists (select * from t t2 where t1.a = t2.a)").Check(testkit.Rows( + " 2", + )) + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(1, null, 0), (2, null, 0)") + tk.MustQuery("select a, b from t t1 where b not in (select a from t t2)").Check(testkit.Rows()) + tk.MustExec("truncate table t") + tk.MustExec("insert into t values(null, 1, 1), (2, 2, 2), (3, null, 3), (4, 4, 3)") + tk.MustQuery("select a, b, a not in (select b from t) from t order by a").Check(testkit.Rows( + " 1 ", + "2 2 0", + "3 ", + "4 4 0", + )) + tk.MustQuery("select a, c, a not in (select c from t) from t order by a").Check(testkit.Rows( + " 1 ", + "2 2 0", + "3 3 0", + "4 3 1", + )) +} diff --git a/executor/joiner.go b/executor/joiner.go index 5423acdca86f8..93633a620d66e 100644 --- a/executor/joiner.go +++ b/executor/joiner.go @@ -35,14 +35,15 @@ var ( // joiner is used to generate join results according to the join type. // A typical instruction flow is: // -// hasMatch := false +// hasMatch, hasNull := false, false // for innerIter.Current() != innerIter.End() { -// matched, err := j.tryToMatch(outer, innerIter, chk) +// matched, isNull, err := j.tryToMatch(outer, innerIter, chk) // // handle err // hasMatch = hasMatch || matched +// hasNull = hasNull || isNull // } // if !hasMatch { -// j.onMissMatch(outer) +// j.onMissMatch(hasNull, outer, chk) // } // // NOTE: This interface is **not** thread-safe. @@ -55,7 +56,7 @@ type joiner interface { // NOTE: Callers need to call this function multiple times to consume all // the inner rows for an outer row, and dicide whether the outer row can be // matched with at lease one inner row. - tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, error) + tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, bool, error) // onMissMatch operates on the unmatched outer row according to the join // type. An outer row can be considered miss matched if: @@ -76,7 +77,11 @@ type joiner interface { // 6. 'RightOuterJoin': concats the unmatched outer row with a row of NULLs // and appends it to the result buffer. // 7. 'InnerJoin': ignores the unmatched outer row. - onMissMatch(outer chunk.Row, chk *chunk.Chunk) + // Not that, for anti join, we need to know the reason of outer row treated as + // unmatched: whether the join condition returns false, or returns null, because + // it decides if this outer row should be outputed, hence we have a `hasNull` + // parameter passed to `onMissMatch`. + onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) } func newJoiner(ctx sessionctx.Context, joinType plannercore.JoinType, @@ -176,34 +181,34 @@ type semiJoiner struct { baseJoiner } -func (j *semiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, err error) { +func (j *semiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } if len(j.conditions) == 0 { chk.AppendPartialRow(0, outer) inners.ReachEnd() - return true, nil + return true, false, nil } for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(j.outerIsRight, inner, outer) - matched, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, _, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } if matched { chk.AppendPartialRow(0, outer) inners.ReachEnd() - return true, nil + return true, false, nil } } - return false, nil + return false, false, nil } -func (j *semiJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *semiJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { } type antiSemiJoiner struct { @@ -211,33 +216,36 @@ type antiSemiJoiner struct { } // tryToMatch implements joiner interface. -func (j *antiSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, err error) { +func (j *antiSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } if len(j.conditions) == 0 { inners.ReachEnd() - return true, nil + return true, false, nil } for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(j.outerIsRight, inner, outer) - matched, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } if matched { inners.ReachEnd() - return true, nil + return true, false, nil } + hasNull = hasNull || isNull } - return false, nil + return false, hasNull, nil } -func (j *antiSemiJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { - chk.AppendRow(outer) +func (j *antiSemiJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { + if !hasNull { + chk.AppendRow(outer) + } } type leftOuterSemiJoiner struct { @@ -245,31 +253,31 @@ type leftOuterSemiJoiner struct { } // tryToMatch implements joiner interface. -func (j *leftOuterSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, err error) { +func (j *leftOuterSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } if len(j.conditions) == 0 { j.onMatch(outer, chk) inners.ReachEnd() - return true, nil + return true, false, nil } for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(false, inner, outer) - matched, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, _, err = expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } if matched { j.onMatch(outer, chk) inners.ReachEnd() - return true, nil + return true, false, nil } } - return false, nil + return false, false, nil } func (j *leftOuterSemiJoiner) onMatch(outer chunk.Row, chk *chunk.Chunk) { @@ -277,7 +285,7 @@ func (j *leftOuterSemiJoiner) onMatch(outer chunk.Row, chk *chunk.Chunk) { chk.AppendInt64(outer.Len(), 1) } -func (j *leftOuterSemiJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *leftOuterSemiJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { chk.AppendPartialRow(0, outer) chk.AppendInt64(outer.Len(), 0) } @@ -287,31 +295,32 @@ type antiLeftOuterSemiJoiner struct { } // tryToMatch implements joiner interface. -func (j *antiLeftOuterSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, err error) { +func (j *antiLeftOuterSemiJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (matched bool, hasNull bool, err error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } if len(j.conditions) == 0 { j.onMatch(outer, chk) inners.ReachEnd() - return true, nil + return true, false, nil } for inner := inners.Current(); inner != inners.End(); inner = inners.Next() { j.makeShallowJoinRow(false, inner, outer) - matched, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) + matched, isNull, err := expression.EvalBool(j.ctx, j.conditions, j.shallowRow.ToRow()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } if matched { j.onMatch(outer, chk) inners.ReachEnd() - return true, nil + return true, false, nil } + hasNull = hasNull || isNull } - return false, nil + return false, hasNull, nil } func (j *antiLeftOuterSemiJoiner) onMatch(outer chunk.Row, chk *chunk.Chunk) { @@ -319,9 +328,13 @@ func (j *antiLeftOuterSemiJoiner) onMatch(outer chunk.Row, chk *chunk.Chunk) { chk.AppendInt64(outer.Len(), 0) } -func (j *antiLeftOuterSemiJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *antiLeftOuterSemiJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { chk.AppendPartialRow(0, outer) - chk.AppendInt64(outer.Len(), 1) + if hasNull { + chk.AppendNull(outer.Len()) + } else { + chk.AppendInt64(outer.Len(), 1) + } } type leftOuterJoiner struct { @@ -329,9 +342,9 @@ type leftOuterJoiner struct { } // tryToMatch implements joiner interface. -func (j *leftOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, error) { +func (j *leftOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, bool, error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } j.chk.Reset() chkForJoin := j.chk @@ -345,18 +358,18 @@ func (j *leftOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk inners.Next() } if len(j.conditions) == 0 { - return true, nil + return true, false, nil } // reach here, chkForJoin is j.chk matched, err := j.filter(chkForJoin, chk, outer.Len()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } - return matched, nil + return matched, false, nil } -func (j *leftOuterJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *leftOuterJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { chk.AppendPartialRow(0, outer) chk.AppendPartialRow(outer.Len(), j.defaultInner) } @@ -366,9 +379,9 @@ type rightOuterJoiner struct { } // tryToMatch implements joiner interface. -func (j *rightOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, error) { +func (j *rightOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, bool, error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } j.chk.Reset() @@ -383,17 +396,17 @@ func (j *rightOuterJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, ch inners.Next() } if len(j.conditions) == 0 { - return true, nil + return true, false, nil } matched, err := j.filter(chkForJoin, chk, outer.Len()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } - return matched, nil + return matched, false, nil } -func (j *rightOuterJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *rightOuterJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { chk.AppendPartialRow(0, j.defaultInner) chk.AppendPartialRow(j.defaultInner.Len(), outer) } @@ -403,9 +416,9 @@ type innerJoiner struct { } // tryToMatch implements joiner interface. -func (j *innerJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, error) { +func (j *innerJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *chunk.Chunk) (bool, bool, error) { if inners.Len() == 0 { - return false, nil + return false, false, nil } j.chk.Reset() chkForJoin := j.chk @@ -421,17 +434,17 @@ func (j *innerJoiner) tryToMatch(outer chunk.Row, inners chunk.Iterator, chk *ch } } if len(j.conditions) == 0 { - return true, nil + return true, false, nil } // reach here, chkForJoin is j.chk matched, err := j.filter(chkForJoin, chk, outer.Len()) if err != nil { - return false, errors.Trace(err) + return false, false, errors.Trace(err) } - return matched, nil + return matched, false, nil } -func (j *innerJoiner) onMissMatch(outer chunk.Row, chk *chunk.Chunk) { +func (j *innerJoiner) onMissMatch(hasNull bool, outer chunk.Row, chk *chunk.Chunk) { } diff --git a/executor/merge_join.go b/executor/merge_join.go index 14456561f826b..d15e0c5418f19 100644 --- a/executor/merge_join.go +++ b/executor/merge_join.go @@ -305,7 +305,7 @@ func (e *MergeJoinExec) joinToChunk(ctx context.Context, chk *chunk.Chunk) (hasM } if cmpResult < 0 { - e.joiner.onMissMatch(e.outerTable.row, chk) + e.joiner.onMissMatch(false, e.outerTable.row, chk) if err != nil { return false, errors.Trace(err) } @@ -319,7 +319,7 @@ func (e *MergeJoinExec) joinToChunk(ctx context.Context, chk *chunk.Chunk) (hasM continue } - matched, err := e.joiner.tryToMatch(e.outerTable.row, e.innerIter4Row, chk) + matched, _, err := e.joiner.tryToMatch(e.outerTable.row, e.innerIter4Row, chk) if err != nil { return false, errors.Trace(err) } @@ -327,7 +327,7 @@ func (e *MergeJoinExec) joinToChunk(ctx context.Context, chk *chunk.Chunk) (hasM if e.innerIter4Row.Current() == e.innerIter4Row.End() { if !e.outerTable.hasMatch { - e.joiner.onMissMatch(e.outerTable.row, chk) + e.joiner.onMissMatch(false, e.outerTable.row, chk) } e.outerTable.row = e.outerTable.iter.Next() e.innerIter4Row.Begin() diff --git a/executor/union_scan.go b/executor/union_scan.go index 187d8e0c7a112..ec96cada04ad7 100644 --- a/executor/union_scan.go +++ b/executor/union_scan.go @@ -286,7 +286,7 @@ func (us *UnionScanExec) buildAndSortAddedRows() error { } } mutableRow.SetDatums(newData...) - matched, err := expression.EvalBool(us.ctx, us.conditions, mutableRow.ToRow()) + matched, _, err := expression.EvalBool(us.ctx, us.conditions, mutableRow.ToRow()) if err != nil { return errors.Trace(err) } diff --git a/expression/chunk_executor.go b/expression/chunk_executor.go index 66d4b2ee08fe9..6ee5a722ccc78 100644 --- a/expression/chunk_executor.go +++ b/expression/chunk_executor.go @@ -252,7 +252,7 @@ func VectorizedFilter(ctx sessionctx.Context, filters []Expression, iterator *ch selected[row.Idx()] = selected[row.Idx()] && !isNull && (filterResult != 0) } else { // TODO: should rewrite the filter to `cast(expr as SIGNED) != 0` and always use `EvalInt`. - bVal, err := EvalBool(ctx, []Expression{filter}, row) + bVal, _, err := EvalBool(ctx, []Expression{filter}, row) if err != nil { return nil, err } diff --git a/expression/constant_propagation.go b/expression/constant_propagation.go index ec69cf62ef05f..9db1c936fbb15 100644 --- a/expression/constant_propagation.go +++ b/expression/constant_propagation.go @@ -245,7 +245,7 @@ func (s *propConstSolver) pickNewEQConds(visited []bool) (retMapper map[int]*Con var ok bool if col == nil { if con, ok = cond.(*Constant); ok { - value, err := EvalBool(s.ctx, []Expression{con}, chunk.Row{}) + value, _, err := EvalBool(s.ctx, []Expression{con}, chunk.Row{}) if err != nil { terror.Log(err) return nil @@ -334,7 +334,7 @@ func (s *propOuterJoinConstSolver) pickEQCondsOnOuterCol(retMapper map[int]*Cons var ok bool if col == nil { if con, ok = cond.(*Constant); ok { - value, err := EvalBool(s.ctx, []Expression{con}, chunk.Row{}) + value, _, err := EvalBool(s.ctx, []Expression{con}, chunk.Row{}) if err != nil { terror.Log(err) return nil diff --git a/expression/expression.go b/expression/expression.go index a2380088b6f34..f3e8757a0afd9 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -111,26 +111,34 @@ func (e CNFExprs) Clone() CNFExprs { return cnf } -// EvalBool evaluates expression list to a boolean value. -func EvalBool(ctx sessionctx.Context, exprList CNFExprs, row chunk.Row) (bool, error) { +// EvalBool evaluates expression list to a boolean value. The first returned value +// indicates bool result of the expression list, the second returned value indicates +// whether the result of the expression list is null, it can only be true when the +// first returned values is false. +func EvalBool(ctx sessionctx.Context, exprList CNFExprs, row chunk.Row) (bool, bool, error) { + hasNull := false for _, expr := range exprList { data, err := expr.Eval(row) if err != nil { - return false, err + return false, false, err } if data.IsNull() { - return false, nil + hasNull = true + continue } i, err := data.ToBool(ctx.GetSessionVars().StmtCtx) if err != nil { - return false, err + return false, false, err } if i == 0 { - return false, nil + return false, false, nil } } - return true, nil + if hasNull { + return false, true, nil + } + return true, false, nil } // composeConditionWithBinaryOp composes condition with binary operator into a balance deep tree, which benefits a lot for pb decoder/encoder. diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 2a4af05983ace..ae5bd6fd6ee1a 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -363,11 +363,11 @@ func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) if v.All { er.handleEQAll(lexpr, rexpr, np) } else { - er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, false) + er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, false, false) } } else if v.Op == opcode.NE { if v.All { - er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, true) + er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, true, true) } else { er.handleNEAny(lexpr, rexpr, np) } @@ -457,7 +457,7 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, // plan4Agg.buildProjectionIfNecessary() if !er.asScalar { // For Semi LogicalApply without aux column, the result is no matter false or null. So we can add it to join predicate. - er.p, er.err = er.b.buildSemiApply(er.p, plan4Agg, []expression.Expression{cond}, false, false) + er.p, er.err = er.b.buildSemiApply(er.p, plan4Agg, []expression.Expression{cond}, false, false, false) return } // If we treat the result as a scalar value, we will add a projection with a extra column to output true, false or null. @@ -545,7 +545,7 @@ func (er *expressionRewriter) handleExistSubquery(v *ast.ExistsSubqueryExpr) (as } np = er.popExistsSubPlan(np) if len(np.extractCorrelatedCols()) > 0 { - er.p, er.err = er.b.buildSemiApply(er.p, np, nil, er.asScalar, v.Not) + er.p, er.err = er.b.buildSemiApply(er.p, np, nil, er.asScalar, v.Not, false) if er.err != nil || !er.asScalar { return v, true } @@ -668,7 +668,7 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, } er.p = join } else { - er.p, er.err = er.b.buildSemiApply(er.p, np, expression.SplitCNFItems(checkCondition), asScalar, v.Not) + er.p, er.err = er.b.buildSemiApply(er.p, np, expression.SplitCNFItems(checkCondition), asScalar, v.Not, v.Not) if er.err != nil { return v, true } diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index df1cf71621744..37f935c8dc97c 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -180,7 +180,7 @@ func (ds *DataSource) tryToGetMemTask(prop *property.PhysicalProperty) (task tas func (ds *DataSource) tryToGetDualTask() (task, error) { for _, cond := range ds.pushedDownConds { if con, ok := cond.(*expression.Constant); ok && con.DeferredExpr == nil { - result, err := expression.EvalBool(ds.ctx, []expression.Expression{cond}, chunk.Row{}) + result, _, err := expression.EvalBool(ds.ctx, []expression.Expression{cond}, chunk.Row{}) if err != nil { return nil, errors.Trace(err) } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 72aa76dddb3e4..99ae8fed7af99 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -491,7 +491,7 @@ func (b *PlanBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMappe cnfItems := expression.SplitCNFItems(expr) for _, item := range cnfItems { if con, ok := item.(*expression.Constant); ok && con.DeferredExpr == nil { - ret, err := expression.EvalBool(b.ctx, expression.CNFExprs{con}, chunk.Row{}) + ret, _, err := expression.EvalBool(b.ctx, expression.CNFExprs{con}, chunk.Row{}) if err != nil || ret { continue } @@ -2105,8 +2105,28 @@ func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan LogicalPlan, t return ap } +func cannotDecorrelate(fromNotIn bool, outerPlan, innerPlan LogicalPlan, condition []expression.Expression) bool { + if !fromNotIn { + return false + } + outerSchema := outerPlan.Schema() + innerSchema := innerPlan.Schema() + cols := make([]*expression.Column, 0, 2*len(condition)) + cols = expression.ExtractColumnsFromExpressions(cols, condition, nil) + for _, col := range cols { + if innerCol := innerSchema.RetrieveColumn(col); innerCol != nil && !mysql.HasNotNullFlag(innerCol.RetType.Flag) { + return true + } + if outerCol := outerSchema.RetrieveColumn(col); outerCol != nil && !mysql.HasNotNullFlag(outerCol.RetType.Flag) { + return true + } + } + return false +} + // buildSemiApply builds apply plan with outerPlan and innerPlan, which apply semi-join for every row from outerPlan and the whole innerPlan. -func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan LogicalPlan, condition []expression.Expression, asScalar, not bool) (LogicalPlan, error) { +func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan LogicalPlan, condition []expression.Expression, + asScalar, not bool, fromNotIn bool) (LogicalPlan, error) { b.optFlag = b.optFlag | flagPredicatePushDown b.optFlag = b.optFlag | flagBuildKeyInfo b.optFlag = b.optFlag | flagDecorrelate @@ -2116,7 +2136,8 @@ func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan LogicalPlan, condition return nil, errors.Trace(err) } - ap := &LogicalApply{LogicalJoin: *join} + cannot := cannotDecorrelate(fromNotIn, outerPlan, innerPlan, condition) + ap := &LogicalApply{LogicalJoin: *join, cannotDecorrelate: cannot} ap.tp = TypeApply ap.self = ap return ap, nil diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 1ea9f82e35dd2..907529d42a8ec 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1889,3 +1889,56 @@ func (s *testPlanSuite) TestSelectView(c *C) { c.Assert(ToString(p), Equals, tt.best, comment) } } + +func (s *testPlanSuite) TestNotInDecorrelate(c *C) { + defer func() { + testleak.AfterTest(c)() + }() + tests := []struct { + sql string + best string + }{ + { + sql: "select * from t t1 where a not in (select e from t t2)", + best: "Apply{DataScan(t1)->DataScan(t2)}->Projection", + }, + { + sql: "select * from t t1 where a != all (select e from t t2)", + best: "Apply{DataScan(t1)->DataScan(t2)}->Sel([5_aux_0])->Projection", + }, + // Column `b` has NotNull flag, so we can decorrelate this `not in`. + { + sql: "select * from t t1 where a not in (select b from t t2)", + best: "Join{DataScan(t1)->DataScan(t2)}(t1.a,t2.b)->Projection", + }, + { + sql: "select * from t t1 where a != all (select b from t t2)", + best: "Join{DataScan(t1)->DataScan(t2)}(t1.a,t2.b)->Sel([5_aux_0])->Projection", + }, + // Column `e` doesn't have NotNull flag, so we cannot decorrelate `not in`. + { + sql: "select * from t t1 where e not in (select a from t t2)", + best: "Apply{DataScan(t1)->DataScan(t2)}->Projection", + }, + { + sql: "select * from t t1 where e != all (select a from t t2)", + best: "Apply{DataScan(t1)->DataScan(t2)}->Sel([5_aux_0])->Projection", + }, + } + for i, tt := range tests { + comment := Commentf("case:%v sql:%s", i, tt.sql) + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + Preprocess(s.ctx, stmt, s.is, false) + builder := &PlanBuilder{ + ctx: MockContext(), + is: s.is, + colMapper: make(map[*ast.ColumnNameExpr]int), + } + p, err := builder.Build(stmt) + c.Assert(err, IsNil) + p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + c.Assert(err, IsNil) + c.Assert(ToString(p), Equals, tt.best, comment) + } +} diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index aa2cc510b67c9..b5d990c88d2d9 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -251,6 +251,12 @@ type LogicalApply struct { LogicalJoin corCols []*expression.CorrelatedColumn + // For LogicalApply generated from `not in` or `!= all`, if column from any side is possible to be null, + // we cannot decorrelate this apply, because during decorrelation, we may pull up equal conditions from + // descendant nodes, these conditions should have various behaviors against the equal conditions from + // `not in / != all` regarding null input, since we don't differentiate these conditions now, decorrelation + // can lead to wrong results. + cannotDecorrelate bool } func (la *LogicalApply) extractCorrelatedCols() []*expression.CorrelatedColumn { diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index f873cd5adc9ad..ebf415734e587 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -129,7 +129,7 @@ func (s *decorrelateSolver) aggDefaultValueMap(agg *LogicalAggregation) map[int] // optimize implements logicalOptRule interface. func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { - if apply, ok := p.(*LogicalApply); ok { + if apply, ok := p.(*LogicalApply); ok && !apply.cannotDecorrelate { outerPlan := apply.children[0] innerPlan := apply.children[1] apply.extractCorColumnsBySchema()