From ac8316488db2016bb14c91612a8a1e544a3b4c4f Mon Sep 17 00:00:00 2001 From: "hao.hu" Date: Thu, 9 Aug 2018 10:01:37 +0800 Subject: [PATCH 1/3] expression: handle max_allowed_packet warnings for repeat function. issue #7153 --- expression/builtin_string.go | 17 +++++++++++++++-- expression/builtin_string_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 4998aa24885c8..04a443f01e843 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -554,17 +554,24 @@ func (c *repeatFunctionClass) getFunction(ctx sessionctx.Context, args []Express bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETInt) bf.tp.Flen = mysql.MaxBlobWidth SetBinFlagOrBinStr(args[0].GetType(), bf.tp) - sig := &builtinRepeatSig{bf} + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + sig := &builtinRepeatSig{bf, maxAllowedPacket} return sig, nil } type builtinRepeatSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinRepeatSig) Clone() builtinFunc { newSig := &builtinRepeatSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -575,6 +582,7 @@ func (b *builtinRepeatSig) evalString(row chunk.Row) (d string, isNull bool, err if isNull || err != nil { return "", isNull, errors.Trace(err) } + byteLength := len(str) num, isNull, err := b.args[1].EvalInt(b.ctx, row) if isNull || err != nil { @@ -587,7 +595,12 @@ func (b *builtinRepeatSig) evalString(row chunk.Row) (d string, isNull bool, err num = math.MaxInt32 } - if int64(len(str)) > int64(b.tp.Flen)/num { + if uint64(byteLength)*uint64(num) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("repeat", b.maxAllowedPacket)) + return "", true, nil + } + + if int64(byteLength) > int64(b.tp.Flen)/num { return "", true, nil } return strings.Repeat(str, int(num)), false, nil diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 3d067da53a335..7df4830b7bb94 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -398,6 +398,37 @@ func (s *testEvaluatorSuite) TestRepeat(c *C) { c.Assert(v.GetString(), Equals, "") } +func (s *testEvaluatorSuite) TestRepeatSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeLonglong}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + } + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + repeat := &builtinRepeatSig{base, 1000} + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, "a") + input.AppendString(0, "a") + input.AppendInt64(1, 6) + input.AppendInt64(1, 10001) + res, isNull, err := repeat.evalString(input.GetRow(0)) + c.Assert(res, Equals, "aaaaaa") + c.Assert(isNull, IsFalse) + c.Assert(err, IsNil) + res, isNull, err = repeat.evalString(input.GetRow(1)) + c.Assert(res, Equals, "") + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, 1) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) +} + func (s *testEvaluatorSuite) TestLower(c *C) { defer testleak.AfterTest(c)() cases := []struct { From 315104ebfdef3ae72a52749608355fee9f2a2242 Mon Sep 17 00:00:00 2001 From: "hao.hu" Date: Mon, 13 Aug 2018 22:10:23 +0800 Subject: [PATCH 2/3] add more test --- expression/builtin_string_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 7df4830b7bb94..bee5eb1da77a2 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -413,8 +413,12 @@ func (s *testEvaluatorSuite) TestRepeatSig(c *C) { input := chunk.NewChunkWithCapacity(colTypes, 10) input.AppendString(0, "a") input.AppendString(0, "a") + input.AppendString(0, "毅") + input.AppendString(0, "毅") input.AppendInt64(1, 6) input.AppendInt64(1, 10001) + input.AppendInt64(1, 6) + input.AppendInt64(1, 334) res, isNull, err := repeat.evalString(input.GetRow(0)) c.Assert(res, Equals, "aaaaaa") c.Assert(isNull, IsFalse) @@ -427,6 +431,18 @@ func (s *testEvaluatorSuite) TestRepeatSig(c *C) { c.Assert(len(warnings), Equals, 1) lastWarn := warnings[len(warnings)-1] c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + res, isNull, err = repeat.evalString(input.GetRow(2)) + c.Assert(res, Equals, "毅毅毅毅毅毅") + c.Assert(isNull, IsFalse) + c.Assert(err, IsNil) + res, isNull, err = repeat.evalString(input.GetRow(3)) + c.Assert(res, Equals, "") + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + warnings = s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, 2) + lastWarn = warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) } func (s *testEvaluatorSuite) TestLower(c *C) { From 8ca6783d9a850f0ef2062bcd15ebc7c6562a294b Mon Sep 17 00:00:00 2001 From: "hao.hu" Date: Wed, 15 Aug 2018 01:17:15 +0800 Subject: [PATCH 3/3] refine TestRepeatSig --- expression/builtin_string_test.go | 64 +++++++++++++++---------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index bee5eb1da77a2..235f1b2588e85 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -410,39 +410,37 @@ func (s *testEvaluatorSuite) TestRepeatSig(c *C) { } base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} repeat := &builtinRepeatSig{base, 1000} - input := chunk.NewChunkWithCapacity(colTypes, 10) - input.AppendString(0, "a") - input.AppendString(0, "a") - input.AppendString(0, "毅") - input.AppendString(0, "毅") - input.AppendInt64(1, 6) - input.AppendInt64(1, 10001) - input.AppendInt64(1, 6) - input.AppendInt64(1, 334) - res, isNull, err := repeat.evalString(input.GetRow(0)) - c.Assert(res, Equals, "aaaaaa") - c.Assert(isNull, IsFalse) - c.Assert(err, IsNil) - res, isNull, err = repeat.evalString(input.GetRow(1)) - c.Assert(res, Equals, "") - c.Assert(isNull, IsTrue) - c.Assert(err, IsNil) - warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() - c.Assert(len(warnings), Equals, 1) - lastWarn := warnings[len(warnings)-1] - c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) - res, isNull, err = repeat.evalString(input.GetRow(2)) - c.Assert(res, Equals, "毅毅毅毅毅毅") - c.Assert(isNull, IsFalse) - c.Assert(err, IsNil) - res, isNull, err = repeat.evalString(input.GetRow(3)) - c.Assert(res, Equals, "") - c.Assert(isNull, IsTrue) - c.Assert(err, IsNil) - warnings = s.ctx.GetSessionVars().StmtCtx.GetWarnings() - c.Assert(len(warnings), Equals, 2) - lastWarn = warnings[len(warnings)-1] - c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + + cases := []struct { + args []interface{} + warning int + res string + }{ + {[]interface{}{"a", int64(6)}, 0, "aaaaaa"}, + {[]interface{}{"a", int64(10001)}, 1, ""}, + {[]interface{}{"毅", int64(6)}, 0, "毅毅毅毅毅毅"}, + {[]interface{}{"毅", int64(334)}, 2, ""}, + } + + for _, t := range cases { + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, t.args[0].(string)) + input.AppendInt64(1, t.args[1].(int64)) + + res, isNull, err := repeat.evalString(input.GetRow(0)) + c.Assert(res, Equals, t.res) + c.Assert(err, IsNil) + if t.warning == 0 { + c.Assert(isNull, IsFalse) + } else { + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, t.warning) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + } + } } func (s *testEvaluatorSuite) TestLower(c *C) {