Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: add max_allowed_packet check in concat/concat_ws #11137

Merged
merged 15 commits into from
Jul 16, 2019
55 changes: 43 additions & 12 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,17 +272,26 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
if bf.tp.Flen >= mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinConcatSig{bf}

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
}

sig := &builtinConcatSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinConcatSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinConcatSig) Clone() builtinFunc {
newSig := &builtinConcatSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -295,6 +304,10 @@ func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err
if isNull || err != nil {
return d, isNull, err
}
if uint64(len(s)+len(d)) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat", b.maxAllowedPacket))
return "", true, nil
}
s = append(s, []byte(d)...)
}
return string(s), false, nil
Expand Down Expand Up @@ -337,17 +350,25 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf.tp.Flen = mysql.MaxBlobWidth
}

sig := &builtinConcatWSSig{bf}
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
}

sig := &builtinConcatWSSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinConcatWSSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinConcatWSSig) Clone() builtinFunc {
newSig := &builtinConcatWSSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -357,25 +378,35 @@ func (b *builtinConcatWSSig) evalString(row chunk.Row) (string, bool, error) {
args := b.getArgs()
strs := make([]string, 0, len(args))
var sep string
for i, arg := range args {
val, isNull, err := arg.EvalString(b.ctx, row)
var targetLength int

N := len(args)
if N > 0 {
val, isNull, err := args[0].EvalString(b.ctx, row)
if err != nil || isNull {
// If the separator is NULL, the result is NULL.
return val, isNull, err
}
sep = val
}
for i := 1; i < N; i++ {
val, isNull, err := args[i].EvalString(b.ctx, row)
if err != nil {
return val, isNull, err
}

if isNull {
// If the separator is NULL, the result is NULL.
if i == 0 {
return val, isNull, nil
}
// CONCAT_WS() does not skip empty strings. However,
// it does skip any NULL values after the separator argument.
continue
}

if i == 0 {
sep = val
continue
targetLength += len(val)
if i > 1 {
targetLength += len(sep)
}
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat_ws", b.maxAllowedPacket))
return "", true, nil
}
strs = append(strs, val)
}
Expand Down
91 changes: 91 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,50 @@ func (s *testEvaluatorSuite) TestConcat(c *C) {
}
}

func (s *testEvaluatorSuite) TestConcatSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}
args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
concat := &builtinConcatSig{base, 5}

cases := []struct {
args []interface{}
warnings int
res string
}{
{[]interface{}{"a", "b"}, 0, "ab"},
{[]interface{}{"aaa", "bbb"}, 1, ""},
{[]interface{}{"中", "a"}, 0, "中a"},
{[]interface{}{"中文", "a"}, 2, ""},
}

for _, t := range cases {
input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, t.args[0].(string))
input.AppendString(1, t.args[1].(string))

res, isNull, err := concat.evalString(input.GetRow(0))
c.Assert(res, Equals, t.res)
c.Assert(err, IsNil)
if t.warnings == 0 {
c.Assert(isNull, IsFalse)
} else {
c.Assert(isNull, IsTrue)
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(warnings, HasLen, t.warnings)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}
}
}

func (s *testEvaluatorSuite) TestConcatWS(c *C) {
defer testleak.AfterTest(c)()
cases := []struct {
Expand Down Expand Up @@ -246,6 +290,53 @@ func (s *testEvaluatorSuite) TestConcatWS(c *C) {
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestConcatWSSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}
args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
&Column{Index: 2, RetType: colTypes[2]},
}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
concat := &builtinConcatWSSig{base, 6}

cases := []struct {
args []interface{}
warnings int
res string
}{
{[]interface{}{",", "a", "b"}, 0, "a,b"},
{[]interface{}{",", "aaa", "bbb"}, 1, ""},
{[]interface{}{",", "中", "a"}, 0, "中,a"},
{[]interface{}{",", "中文", "a"}, 2, ""},
}

for _, t := range cases {
input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, t.args[0].(string))
input.AppendString(1, t.args[1].(string))
input.AppendString(2, t.args[2].(string))

res, isNull, err := concat.evalString(input.GetRow(0))
c.Assert(res, Equals, t.res)
c.Assert(err, IsNil)
if t.warnings == 0 {
c.Assert(isNull, IsFalse)
} else {
c.Assert(isNull, IsTrue)
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(warnings, HasLen, t.warnings)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}
}
}

func (s *testEvaluatorSuite) TestLeft(c *C) {
defer testleak.AfterTest(c)()
stmtCtx := s.ctx.GetSessionVars().StmtCtx
Expand Down
3 changes: 3 additions & 0 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ func NewContext() *Context {
sctx.sessionVars.MaxChunkSize = 32
sctx.sessionVars.StmtCtx.TimeZone = time.UTC
sctx.sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor()
if err := sctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864"); err != nil {
panic(err)
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
}
return sctx
}

Expand Down