Skip to content

Commit

Permalink
expression: return null when cast to huge binary type (#8768) (#9349)
Browse files Browse the repository at this point in the history
  • Loading branch information
eurekaka authored and zimulala committed Feb 19, 2019
1 parent e491ceb commit 5fe4c27
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 18 deletions.
59 changes: 47 additions & 12 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -529,8 +530,11 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul
}
res = strconv.FormatUint(uVal, 10)
}
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastIntAsTimeSig struct {
Expand Down Expand Up @@ -806,11 +810,11 @@ func (b *builtinCastRealAsStringSig) evalString(row chunk.Row) (res string, isNu
// If we strconv.FormatFloat the value with 64bits, the result is incorrect!
bits = 32
}
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx)
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx, false)
if err != nil {
return res, false, err
}
return res, isNull, errors.Trace(err)
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastRealAsTimeSig struct {
Expand Down Expand Up @@ -940,8 +944,11 @@ func (b *builtinCastDecimalAsStringSig) evalString(row chunk.Row) (res string, i
return res, isNull, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastDecimalAsRealSig struct {
Expand Down Expand Up @@ -1036,8 +1043,11 @@ func (b *builtinCastStringAsStringSig) evalString(row chunk.Row) (res string, is
return res, isNull, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastStringAsIntSig struct {
Expand Down Expand Up @@ -1334,8 +1344,11 @@ func (b *builtinCastTimeAsStringSig) evalString(row chunk.Row) (res string, isNu
return res, isNull, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastTimeAsDurationSig struct {
Expand Down Expand Up @@ -1458,8 +1471,30 @@ func (b *builtinCastDurationAsStringSig) evalString(row chunk.Row) (res string,
return res, isNull, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

func padZeroForBinaryType(s string, tp *types.FieldType, ctx sessionctx.Context) (string, bool, error) {
flen := tp.Flen
if tp.Tp == mysql.TypeString && types.IsBinaryStr(tp) && len(s) < flen {
sc := ctx.GetSessionVars().StmtCtx
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return "", false, err
}
if uint64(flen) > maxAllowedPacket {
sc.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("cast_as_binary", maxAllowedPacket))
return "", true, nil
}
padding := make([]byte, flen-len(s))
s = string(append([]byte(s), padding...))
}
return s, false, nil
}

type builtinCastDurationAsTimeSig struct {
Expand Down
15 changes: 13 additions & 2 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ func (s *testEvaluatorSuite) TestCast(c *C) {
c.Assert(len(res.GetString()), Equals, 5)
c.Assert(res.GetString(), Equals, string([]byte{'a', 0x00, 0x00, 0x00, 0x00}))

// cast(str as binary(N)), N > len([]byte(str)).
// cast("a" as binary(4294967295))
tp.Flen = 4294967295
f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("a"), RetType: types.NewFieldType(mysql.TypeString)}, tp)
res, err = f.Eval(chunk.Row{})
c.Assert(err, IsNil)
c.Assert(res.IsNull(), IsTrue)
warnings := sc.GetWarnings()
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err))

origSc := sc
sc.InSelectStmt = true
sc.OverflowAsWarning = true
Expand All @@ -93,8 +104,8 @@ func (s *testEvaluatorSuite) TestCast(c *C) {
c.Assert(err, IsNil)
c.Assert(res.GetUint64() == math.MaxUint64, IsTrue)

warnings := sc.GetWarnings()
lastWarn := warnings[len(warnings)-1]
warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err))

originFlag := tp1.Flag
Expand Down
9 changes: 5 additions & 4 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,16 +804,17 @@ func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType)
default:
return invalidConv(d, target.Tp)
}
s, err := ProduceStrWithSpecifiedTp(s, target, sc)
s, err := ProduceStrWithSpecifiedTp(s, target, sc, true)
ret.SetString(s)
if target.Charset == charset.CharsetBin {
ret.k = KindBytes
}
return ret, errors.Trace(err)
}

// ProduceStrWithSpecifiedTp produces a new string according to `flen` and `chs`.
func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementContext) (_ string, err error) {
// ProduceStrWithSpecifiedTp produces a new string according to `flen` and `chs`. Param `padZero` indicates
// whether we should pad `\0` for `binary(flen)` type.
func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementContext, padZero bool) (_ string, err error) {
flen, chs := tp.Flen, tp.Charset
if flen >= 0 {
// Flen is the rune length, not binary length, for UTF8 charset, we need to calculate the
Expand Down Expand Up @@ -842,7 +843,7 @@ func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementCon
} else if len(s) > flen {
err = ErrDataTooLong.GenWithStack("Data Too Long, field len %d, data len %d", flen, len(s))
s = truncateStr(s, flen)
} else if tp.Tp == mysql.TypeString && IsBinaryStr(tp) && len(s) < flen {
} else if tp.Tp == mysql.TypeString && IsBinaryStr(tp) && len(s) < flen && padZero {
padding := make([]byte, flen-len(s))
s = string(append([]byte(s), padding...))
}
Expand Down

0 comments on commit 5fe4c27

Please sign in to comment.