diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 8b2a772565420..f63ea4fe483e7 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -555,17 +555,24 @@ func (c *repeatFunctionClass) getFunction(ctx sessionctx.Context, args []Express bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETInt) bf.tp.Flen = mysql.MaxBlobWidth SetBinFlagOrBinStr(args[0].GetType(), bf.tp) - sig := &builtinRepeatSig{bf} + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + sig := &builtinRepeatSig{bf, maxAllowedPacket} return sig, nil } type builtinRepeatSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinRepeatSig) Clone() builtinFunc { newSig := &builtinRepeatSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -576,6 +583,7 @@ func (b *builtinRepeatSig) evalString(row chunk.Row) (d string, isNull bool, err if isNull || err != nil { return "", isNull, errors.Trace(err) } + byteLength := len(str) num, isNull, err := b.args[1].EvalInt(b.ctx, row) if isNull || err != nil { @@ -588,7 +596,12 @@ func (b *builtinRepeatSig) evalString(row chunk.Row) (d string, isNull bool, err num = math.MaxInt32 } - if int64(len(str)) > int64(b.tp.Flen)/num { + if uint64(byteLength)*uint64(num) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("repeat", b.maxAllowedPacket)) + return "", true, nil + } + + if int64(byteLength) > int64(b.tp.Flen)/num { return "", true, nil } return strings.Repeat(str, int(num)), false, nil diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 2f7b9d4e0ab93..09e6d961e4fd1 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -399,6 +399,51 @@ func (s *testEvaluatorSuite) TestRepeat(c *C) { c.Assert(v.GetString(), Equals, "") } +func (s *testEvaluatorSuite) TestRepeatSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeLonglong}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + } + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + repeat := &builtinRepeatSig{base, 1000} + + cases := []struct { + args []interface{} + warning int + res string + }{ + {[]interface{}{"a", int64(6)}, 0, "aaaaaa"}, + {[]interface{}{"a", int64(10001)}, 1, ""}, + {[]interface{}{"毅", int64(6)}, 0, "毅毅毅毅毅毅"}, + {[]interface{}{"毅", int64(334)}, 2, ""}, + } + + for _, t := range cases { + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, t.args[0].(string)) + input.AppendInt64(1, t.args[1].(int64)) + + res, isNull, err := repeat.evalString(input.GetRow(0)) + c.Assert(res, Equals, t.res) + c.Assert(err, IsNil) + if t.warning == 0 { + c.Assert(isNull, IsFalse) + } else { + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, t.warning) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + } + } +} + func (s *testEvaluatorSuite) TestLower(c *C) { defer testleak.AfterTest(c)() cases := []struct {