diff --git a/pkg/expression/builtin_compare.go b/pkg/expression/builtin_compare.go index 18a935b274829..f7b366f8aeda8 100644 --- a/pkg/expression/builtin_compare.go +++ b/pkg/expression/builtin_compare.go @@ -1603,7 +1603,7 @@ func matchRefineRule3Pattern(conEvalType types.EvalType, exprType *types.FieldTy (conEvalType == types.ETReal || conEvalType == types.ETDecimal || conEvalType == types.ETInt) } -// handleDurationTypeComparison handles comparisons between a duration type column and a non-duration type constant. +// handleDurationTypeComparisonForNullEq handles comparisons between a duration type column and a non-duration type constant. // If the constant cannot be cast to a duration type and the comparison operator is `<=>`, the expression is rewritten as `0 <=> 1`. // This is necessary to maintain compatibility with MySQL behavior under the following conditions: // 1. When a duration type column is compared with a non-duration type constant, MySQL casts the duration column to the non-duration type. @@ -1616,7 +1616,7 @@ func matchRefineRule3Pattern(conEvalType types.EvalType, exprType *types.FieldTy // // To ensure MySQL compatibility, we need to handle this case specifically. If the non-duration type constant cannot be cast to a duration type, // we rewrite the expression to always return false by converting it to `0 <=> 1`. -func (c *compareFunctionClass) handleDurationTypeComparison(ctx BuildContext, arg0, arg1 Expression) (_ []Expression, err error) { +func (c *compareFunctionClass) handleDurationTypeComparisonForNullEq(ctx BuildContext, arg0, arg1 Expression) (_ []Expression, err error) { // check if a constant value becomes null after being cast to a duration type. castToDurationIsNull := func(ctx BuildContext, arg Expression) (bool, error) { f := WrapWithCastAsDuration(ctx, arg) @@ -1632,8 +1632,16 @@ func (c *compareFunctionClass) handleDurationTypeComparison(ctx BuildContext, ar var isNull bool if arg0IsCon && arg0Const.DeferredExpr == nil && !arg1IsCon && arg1.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDuration { + if arg0Const.Value.IsNull() { + // This is a const null, there is no need to re-write the expression + return nil, nil + } isNull, err = castToDurationIsNull(ctx, arg0) } else if arg1IsCon && arg1Const.DeferredExpr == nil && !arg0IsCon && arg0.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDuration { + if arg1Const.Value.IsNull() { + // This is a const null, there is no need to re-write the expression + return nil, nil + } isNull, err = castToDurationIsNull(ctx, arg1) } if err != nil { @@ -1724,7 +1732,7 @@ func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) ( // Handle comparison between a duration type column and a non-duration type constant. if c.op == opcode.NullEQ { - if result, err := c.handleDurationTypeComparison(ctx, args[0], args[1]); err != nil || result != nil { + if result, err := c.handleDurationTypeComparisonForNullEq(ctx, args[0], args[1]); err != nil || result != nil { return result, err } } diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index b35d3a74240f5..8b43a52534d25 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -3881,6 +3881,15 @@ func TestIssue51842(t *testing.T) { require.Equal(t, 0, len(res)) res = tk.MustQuery("SELECT f1 FROM (SELECT NULLIF(v0.c0, 1371581446) AS f1 FROM v0, t0) AS t WHERE f1 <=> cast('2024-1-1 10:10:10' as datetime);").String() // test datetime require.Equal(t, 0, len(res)) + + // Test issue 56744 + tk.MustExec("drop table if exists lrr;") + tk.MustExec("create table lrr(`COL1` time DEFAULT NULL,`COL2` time DEFAULT NULL);") + tk.MustExec("insert into lrr(col2) values('-229:53:34');") + resultRows := tk.MustQuery("select * from lrr where col1 <=> null;").Rows() // test const null + require.Equal(t, 1, len(resultRows)) + resultRows = tk.MustQuery("select * from lrr where null <=> col1;").Rows() // test const null + require.Equal(t, 1, len(resultRows)) } func TestIssue44706(t *testing.T) {