diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 5731e2abee065..b77db9201cb97 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" @@ -529,8 +530,11 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul } res = strconv.FormatUint(uVal, 10) } - res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx) - return res, false, errors.Trace(err) + res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx, false) + if err != nil { + return res, false, err + } + return padZeroForBinaryType(res, b.tp, b.ctx) } type builtinCastIntAsTimeSig struct { @@ -806,11 +810,11 @@ func (b *builtinCastRealAsStringSig) evalString(row chunk.Row) (res string, isNu // If we strconv.FormatFloat the value with 64bits, the result is incorrect! bits = 32 } - res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx) + res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, b.ctx.GetSessionVars().StmtCtx, false) if err != nil { return res, false, err } - return res, isNull, errors.Trace(err) + return padZeroForBinaryType(res, b.tp, b.ctx) } type builtinCastRealAsTimeSig struct { @@ -940,8 +944,11 @@ func (b *builtinCastDecimalAsStringSig) evalString(row chunk.Row) (res string, i return res, isNull, errors.Trace(err) } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc) - return res, false, errors.Trace(err) + res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc, false) + if err != nil { + return res, false, err + } + return padZeroForBinaryType(res, b.tp, b.ctx) } type builtinCastDecimalAsRealSig struct { @@ -1036,8 +1043,11 @@ func (b *builtinCastStringAsStringSig) evalString(row chunk.Row) (res string, is return res, isNull, errors.Trace(err) } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc) - return res, false, errors.Trace(err) + res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc, false) + if err != nil { + return res, false, err + } + return padZeroForBinaryType(res, b.tp, b.ctx) } type builtinCastStringAsIntSig struct { @@ -1334,8 +1344,11 @@ func (b *builtinCastTimeAsStringSig) evalString(row chunk.Row) (res string, isNu return res, isNull, errors.Trace(err) } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc) - return res, false, errors.Trace(err) + res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc, false) + if err != nil { + return res, false, err + } + return padZeroForBinaryType(res, b.tp, b.ctx) } type builtinCastTimeAsDurationSig struct { @@ -1458,8 +1471,30 @@ func (b *builtinCastDurationAsStringSig) evalString(row chunk.Row) (res string, return res, isNull, errors.Trace(err) } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc) - return res, false, errors.Trace(err) + res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc, false) + if err != nil { + return res, false, err + } + return padZeroForBinaryType(res, b.tp, b.ctx) +} + +func padZeroForBinaryType(s string, tp *types.FieldType, ctx sessionctx.Context) (string, bool, error) { + flen := tp.Flen + if tp.Tp == mysql.TypeString && types.IsBinaryStr(tp) && len(s) < flen { + sc := ctx.GetSessionVars().StmtCtx + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return "", false, err + } + if uint64(flen) > maxAllowedPacket { + sc.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("cast_as_binary", maxAllowedPacket)) + return "", true, nil + } + padding := make([]byte, flen-len(s)) + s = string(append([]byte(s), padding...)) + } + return s, false, nil } type builtinCastDurationAsTimeSig struct { diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index 1e4ef80c7a0e4..05207acfde41a 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -76,6 +76,17 @@ func (s *testEvaluatorSuite) TestCast(c *C) { c.Assert(len(res.GetString()), Equals, 5) c.Assert(res.GetString(), Equals, string([]byte{'a', 0x00, 0x00, 0x00, 0x00})) + // cast(str as binary(N)), N > len([]byte(str)). + // cast("a" as binary(4294967295)) + tp.Flen = 4294967295 + f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("a"), RetType: types.NewFieldType(mysql.TypeString)}, tp) + res, err = f.Eval(chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(res.IsNull(), IsTrue) + warnings := sc.GetWarnings() + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) + origSc := sc sc.InSelectStmt = true sc.OverflowAsWarning = true @@ -93,8 +104,8 @@ func (s *testEvaluatorSuite) TestCast(c *C) { c.Assert(err, IsNil) c.Assert(res.GetUint64() == math.MaxUint64, IsTrue) - warnings := sc.GetWarnings() - lastWarn := warnings[len(warnings)-1] + warnings = sc.GetWarnings() + lastWarn = warnings[len(warnings)-1] c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) originFlag := tp1.Flag diff --git a/types/datum.go b/types/datum.go index 33558c416cfdb..ed340e00aaf33 100644 --- a/types/datum.go +++ b/types/datum.go @@ -804,7 +804,7 @@ func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType) default: return invalidConv(d, target.Tp) } - s, err := ProduceStrWithSpecifiedTp(s, target, sc) + s, err := ProduceStrWithSpecifiedTp(s, target, sc, true) ret.SetString(s) if target.Charset == charset.CharsetBin { ret.k = KindBytes @@ -812,8 +812,9 @@ func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType) return ret, errors.Trace(err) } -// ProduceStrWithSpecifiedTp produces a new string according to `flen` and `chs`. -func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementContext) (_ string, err error) { +// ProduceStrWithSpecifiedTp produces a new string according to `flen` and `chs`. Param `padZero` indicates +// whether we should pad `\0` for `binary(flen)` type. +func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementContext, padZero bool) (_ string, err error) { flen, chs := tp.Flen, tp.Charset if flen >= 0 { // Flen is the rune length, not binary length, for UTF8 charset, we need to calculate the @@ -842,7 +843,7 @@ func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementCon } else if len(s) > flen { err = ErrDataTooLong.GenWithStack("Data Too Long, field len %d, data len %d", flen, len(s)) s = truncateStr(s, flen) - } else if tp.Tp == mysql.TypeString && IsBinaryStr(tp) && len(s) < flen { + } else if tp.Tp == mysql.TypeString && IsBinaryStr(tp) && len(s) < flen && padZero { padding := make([]byte, flen-len(s)) s = string(append([]byte(s), padding...)) }