Skip to content

Commit

Permalink
expression: check if period is valid in period_add (#10430)
Browse files Browse the repository at this point in the history
  • Loading branch information
qw4990 authored and ngaut committed May 15, 2019
1 parent 205418a commit 08c4559
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 10 deletions.
14 changes: 10 additions & 4 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -4887,6 +4887,11 @@ func (c *periodAddFunctionClass) getFunction(ctx sessionctx.Context, args []Expr
return sig, nil
}

// validPeriod checks if this period is valid, it comes from MySQL 8.0+.
func validPeriod(p int64) bool {
return !(p < 0 || p%100 == 0 || p%100 > 12)
}

// period2Month converts a period to months, in which period is represented in the format of YYMM or YYYYMM.
// Note that the period argument is not a date value.
func period2Month(period uint64) uint64 {
Expand Down Expand Up @@ -4938,15 +4943,16 @@ func (b *builtinPeriodAddSig) evalInt(row chunk.Row) (int64, bool, error) {
return 0, true, errors.Trace(err)
}

if p == 0 {
return 0, false, nil
}

n, isNull, err := b.args[1].EvalInt(b.ctx, row)
if isNull || err != nil {
return 0, true, errors.Trace(err)
}

// in MySQL, if p is invalid but n is NULL, the result is NULL, so we have to check if n is NULL first.
if !validPeriod(p) {
return 0, false, errIncorrectArgs.GenWithStackByArgs("period_add")
}

sumMonth := int64(period2Month(uint64(p))) + n
return int64(month2Period(uint64(sumMonth))), false, nil
}
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2146,8 +2146,8 @@ func (s *testEvaluatorSuite) TestPeriodAdd(c *C) {
{201611, -13, true, 201510},
{1611, 3, true, 201702},
{7011, 3, true, 197102},
{12323, 10, true, 12509},
{0, 3, true, 0},
{12323, 10, false, 0},
{0, 3, false, 0},
}

fc := funcs[ast.PeriodAdd]
Expand Down
14 changes: 10 additions & 4 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1450,10 +1450,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) {
result.Check(testkit.Rows("123456 10 <nil> <nil>"))

// for period_add
result = tk.MustQuery(`SELECT period_add(191, 2), period_add(191, -2), period_add(0, 20), period_add(0, 0);`)
result.Check(testkit.Rows("200809 200805 0 0"))
result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("21aa", "11aa"), period_add("", "");`)
result.Check(testkit.Rows("<nil> <nil> <nil> 200010 200208 0"))
result = tk.MustQuery(`SELECT period_add(200807, 2), period_add(200807, -2);`)
result.Check(testkit.Rows("200809 200805"))
result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("200207aa", "1aa");`)
result.Check(testkit.Rows("<nil> <nil> <nil> 200010 200208"))
for _, errPeriod := range []string{
"period_add(0, 20)", "period_add(0, 0)", "period_add(-1, 1)", "period_add(200013, 1)", "period_add(-200012, 1)", "period_add('', '')",
} {
err := tk.QueryToErr(fmt.Sprintf("SELECT %v;", errPeriod))
c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_add")
}

// for period_diff
result = tk.MustQuery(`SELECT period_diff(191, 2), period_diff(191, -2), period_diff(0, 0), period_diff(191, 191);`)
Expand Down
11 changes: 11 additions & 0 deletions util/testkit/testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,17 @@ func (tk *TestKit) MustQuery(sql string, args ...interface{}) *Result {
return tk.ResultSetToResult(rs, comment)
}

// QueryToErr executes a sql statement and discard results.
func (tk *TestKit) QueryToErr(sql string, args ...interface{}) error {
comment := check.Commentf("sql:%s, args:%v", sql, args)
res, err := tk.Exec(sql, args...)
tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment)
tk.c.Assert(res, check.NotNil, comment)
_, resErr := session.GetRows4Test(context.Background(), tk.Se, res)
tk.c.Assert(res.Close(), check.IsNil)
return resErr
}

// ResultSetToResult converts sqlexec.RecordSet to testkit.Result.
// It is used to check results of execute statement in binary mode.
func (tk *TestKit) ResultSetToResult(rs sqlexec.RecordSet, comment check.CommentInterface) *Result {
Expand Down

0 comments on commit 08c4559

Please sign in to comment.