Skip to content

Commit

Permalink
expression: fix expression that duration type null equal with const n…
Browse files Browse the repository at this point in the history
…ull (#56768) (#56787)

close #56744
  • Loading branch information
ti-chi-bot authored Oct 23, 2024
1 parent 52b0b58 commit 27e78ae
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
14 changes: 11 additions & 3 deletions pkg/expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,7 @@ func matchRefineRule3Pattern(conEvalType types.EvalType, exprType *types.FieldTy
(conEvalType == types.ETReal || conEvalType == types.ETDecimal || conEvalType == types.ETInt)
}

// handleDurationTypeComparison handles comparisons between a duration type column and a non-duration type constant.
// handleDurationTypeComparisonForNullEq handles comparisons between a duration type column and a non-duration type constant.
// If the constant cannot be cast to a duration type and the comparison operator is `<=>`, the expression is rewritten as `0 <=> 1`.
// This is necessary to maintain compatibility with MySQL behavior under the following conditions:
// 1. When a duration type column is compared with a non-duration type constant, MySQL casts the duration column to the non-duration type.
Expand All @@ -1616,7 +1616,7 @@ func matchRefineRule3Pattern(conEvalType types.EvalType, exprType *types.FieldTy
//
// To ensure MySQL compatibility, we need to handle this case specifically. If the non-duration type constant cannot be cast to a duration type,
// we rewrite the expression to always return false by converting it to `0 <=> 1`.
func (c *compareFunctionClass) handleDurationTypeComparison(ctx BuildContext, arg0, arg1 Expression) (_ []Expression, err error) {
func (c *compareFunctionClass) handleDurationTypeComparisonForNullEq(ctx BuildContext, arg0, arg1 Expression) (_ []Expression, err error) {
// check if a constant value becomes null after being cast to a duration type.
castToDurationIsNull := func(ctx BuildContext, arg Expression) (bool, error) {
f := WrapWithCastAsDuration(ctx, arg)
Expand All @@ -1632,8 +1632,16 @@ func (c *compareFunctionClass) handleDurationTypeComparison(ctx BuildContext, ar

var isNull bool
if arg0IsCon && arg0Const.DeferredExpr == nil && !arg1IsCon && arg1.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDuration {
if arg0Const.Value.IsNull() {
// This is a const null, there is no need to re-write the expression
return nil, nil
}
isNull, err = castToDurationIsNull(ctx, arg0)
} else if arg1IsCon && arg1Const.DeferredExpr == nil && !arg0IsCon && arg0.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDuration {
if arg1Const.Value.IsNull() {
// This is a const null, there is no need to re-write the expression
return nil, nil
}
isNull, err = castToDurationIsNull(ctx, arg1)
}
if err != nil {
Expand Down Expand Up @@ -1724,7 +1732,7 @@ func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) (

// Handle comparison between a duration type column and a non-duration type constant.
if c.op == opcode.NullEQ {
if result, err := c.handleDurationTypeComparison(ctx, args[0], args[1]); err != nil || result != nil {
if result, err := c.handleDurationTypeComparisonForNullEq(ctx, args[0], args[1]); err != nil || result != nil {
return result, err
}
}
Expand Down
9 changes: 9 additions & 0 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3881,6 +3881,15 @@ func TestIssue51842(t *testing.T) {
require.Equal(t, 0, len(res))
res = tk.MustQuery("SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> cast('2024-1-1 10:10:10' as datetime);").String() // test datetime
require.Equal(t, 0, len(res))

// Test issue 56744
tk.MustExec("drop table if exists lrr;")
tk.MustExec("create table lrr(`COL1` time DEFAULT NULL,`COL2` time DEFAULT NULL);")
tk.MustExec("insert into lrr(col2) values('-229:53:34');")
resultRows := tk.MustQuery("select * from lrr where col1 <=> null;").Rows() // test const null
require.Equal(t, 1, len(resultRows))
resultRows = tk.MustQuery("select * from lrr where null <=> col1;").Rows() // test const null
require.Equal(t, 1, len(resultRows))
}

func TestIssue44706(t *testing.T) {
Expand Down

0 comments on commit 27e78ae

Please sign in to comment.