Skip to content

Commit

Permalink
Fix bit default value bug (#7249)
Browse files Browse the repository at this point in the history
  • Loading branch information
crazycs520 authored Aug 31, 2018
1 parent 341dc10 commit 8d1acc2
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 14 deletions.
9 changes: 9 additions & 0 deletions ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,15 @@ func (s *testDBSuite) TestCreateTable(c *C) {
c.Assert(err, NotNil)
}

func (s *testDBSuite) TestBitDefaultValue(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("create table t_bit (c1 bit(10) default 250, c2 int);")
tk.MustExec("insert into t_bit set c2=1;")
tk.MustQuery("select bin(c1),c2 from t_bit").Check(testkit.Rows("11111010 1"))
tk.MustExec("drop table t_bit")
}

func (s *testDBSuite) TestCreateTableWithPartition(c *C) {
s.tk = testkit.NewTestKit(c, s.store)
s.tk.MustExec("use test;")
Expand Down
28 changes: 19 additions & 9 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o
if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil {
return nil, nil, errors.Trace(err)
}
col.DefaultValue = value
if err = col.SetDefaultValue(value); err != nil {
return nil, nil, errors.Trace(err)
}
removeOnUpdateNowFlag(col)
case ast.ColumnOptionOnUpdate:
// TODO: Support other time functions.
Expand Down Expand Up @@ -473,9 +475,9 @@ func setTimestampDefaultValue(c *table.Column, hasDefaultValue bool, setOnUpdate
// For timestamp Col, if is not set default value or not set null, use current timestamp.
if mysql.HasTimestampFlag(c.Flag) && mysql.HasNotNullFlag(c.Flag) {
if setOnUpdateNow {
c.DefaultValue = types.ZeroDatetimeStr
c.SetDefaultValue(types.ZeroDatetimeStr)
} else {
c.DefaultValue = strings.ToUpper(ast.CurrentTimestamp)
c.SetDefaultValue(strings.ToUpper(ast.CurrentTimestamp))
}
}
}
Expand All @@ -500,7 +502,7 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
return nil
}

if c.DefaultValue != nil {
if c.GetDefaultValue() != nil {
if _, err := table.GetColDefaultValue(ctx, c.ToInfo()); err != nil {
return types.ErrInvalidDefault.GenByArgs(c.Name)
}
Expand All @@ -522,7 +524,7 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
// checkPriKeyConstraint check all parts of a PRIMARY KEY must be NOT NULL
func checkPriKeyConstraint(col *table.Column, hasDefaultValue, hasNullFlag bool, outPriKeyConstraint *ast.Constraint) error {
// Primary key should not be null.
if mysql.HasPriKeyFlag(col.Flag) && hasDefaultValue && col.DefaultValue == nil {
if mysql.HasPriKeyFlag(col.Flag) && hasDefaultValue && col.GetDefaultValue() == nil {
return types.ErrInvalidDefault.GenByArgs(col.Name)
}
// Set primary key flag for outer primary key constraint.
Expand Down Expand Up @@ -1247,7 +1249,7 @@ func (d *ddl) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTab
if err != nil {
return errors.Trace(err)
}
col.OriginDefaultValue = col.DefaultValue
col.OriginDefaultValue = col.GetDefaultValue()
if col.OriginDefaultValue == nil && mysql.HasNotNullFlag(col.Flag) {
zeroVal := table.GetZeroValue(col.ToInfo())
col.OriginDefaultValue, err = zeroVal.ToString()
Expand Down Expand Up @@ -1458,7 +1460,10 @@ func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
if err != nil {
return ErrColumnBadNull.Gen("invalid default value - %s", err)
}
col.DefaultValue = value
err = col.SetDefaultValue(value)
if err != nil {
return errors.Trace(err)
}
return errors.Trace(checkDefaultValue(ctx, col, true))
}

Expand Down Expand Up @@ -1487,7 +1492,9 @@ func setDefaultAndComment(ctx sessionctx.Context, col *table.Column, options []*
if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil {
return errors.Trace(err)
}
col.DefaultValue = value
if err = col.SetDefaultValue(value); err != nil {
return errors.Trace(err)
}
case ast.ColumnOptionComment:
err := setColumnComment(ctx, col, opt)
if err != nil {
Expand Down Expand Up @@ -1709,7 +1716,10 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt
// Clean the NoDefaultValueFlag value.
col.Flag &= ^mysql.NoDefaultValueFlag
if len(specNewColumn.Options) == 0 {
col.DefaultValue = nil
err = col.SetDefaultValue(nil)
if err != nil {
return errors.Trace(err)
}
setNoDefaultValueFlag(col, false)
} else {
err = setDefaultValue(ctx, col, specNewColumn.Options[0])
Expand Down
5 changes: 3 additions & 2 deletions executor/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,8 @@ func (e *ShowExec) fetchShowCreateTable() error {
buf.WriteString(" NOT NULL")
}
if !mysql.HasNoDefaultValueFlag(col.Flag) {
switch col.DefaultValue {
defaultValue := col.GetDefaultValue()
switch defaultValue {
case nil:
if !mysql.HasNotNullFlag(col.Flag) {
if col.Tp == mysql.TypeTimestamp {
Expand All @@ -514,7 +515,7 @@ func (e *ShowExec) fetchShowCreateTable() error {
case "CURRENT_TIMESTAMP":
buf.WriteString(" DEFAULT CURRENT_TIMESTAMP")
default:
defaultValStr := fmt.Sprintf("%v", col.DefaultValue)
defaultValStr := fmt.Sprintf("%v", defaultValue)
if col.Tp == mysql.TypeBit {
defaultValBinaryLiteral := types.BinaryLiteral(defaultValStr)
buf.WriteString(fmt.Sprintf(" DEFAULT %s", defaultValBinaryLiteral.ToBitLiteralString(true)))
Expand Down
27 changes: 27 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tipb/go-tipb"
)

Expand Down Expand Up @@ -69,6 +70,7 @@ type ColumnInfo struct {
Offset int `json:"offset"`
OriginDefaultValue interface{} `json:"origin_default"`
DefaultValue interface{} `json:"default"`
DefaultValueBit []byte `json:"default_bit"`
GeneratedExprString string `json:"generated_expr_string"`
GeneratedStored bool `json:"generated_stored"`
Dependences map[string]struct{} `json:"dependences"`
Expand All @@ -88,6 +90,31 @@ func (c *ColumnInfo) IsGenerated() bool {
return len(c.GeneratedExprString) != 0
}

// SetDefaultValue sets the default value.
func (c *ColumnInfo) SetDefaultValue(value interface{}) error {
c.DefaultValue = value
if c.Tp == mysql.TypeBit {
// For mysql.TypeBit type, the default value storage format must be a string.
// Other value such as int must convert to string format first.
if v, ok := value.(string); ok {
c.DefaultValueBit = []byte(v)
return nil
}
return types.ErrInvalidDefault.GenByArgs(c.Name)
}
return nil
}

// GetDefaultValue gets the default value of the column.
// Default value use to stored in DefaultValue field, but now,
// bit type default value will store in DefaultValueBit for fix bit default value decode/encode bug.
func (c *ColumnInfo) GetDefaultValue() interface{} {
if c.Tp == mysql.TypeBit && c.DefaultValueBit != nil {
return hack.String(c.DefaultValueBit)
}
return c.DefaultValue
}

// FindColumnInfo finds ColumnInfo in cols by name.
func FindColumnInfo(cols []*ColumnInfo, name string) *ColumnInfo {
name = strings.ToLower(name)
Expand Down
4 changes: 2 additions & 2 deletions table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func NewColDesc(col *Column) *ColDesc {
}
var defaultValue interface{}
if !mysql.HasNoDefaultValueFlag(col.Flag) {
defaultValue = col.DefaultValue
defaultValue = col.GetDefaultValue()
}

extra := ""
Expand Down Expand Up @@ -310,7 +310,7 @@ func GetColOriginDefaultValue(ctx sessionctx.Context, col *model.ColumnInfo) (ty

// GetColDefaultValue gets default value of the column.
func GetColDefaultValue(ctx sessionctx.Context, col *model.ColumnInfo) (types.Datum, error) {
return getColDefaultValue(ctx, col, col.DefaultValue)
return getColDefaultValue(ctx, col, col.GetDefaultValue())
}

func getColDefaultValue(ctx sessionctx.Context, col *model.ColumnInfo, defaultVal interface{}) (types.Datum, error) {
Expand Down
2 changes: 1 addition & 1 deletion table/tables/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ func CanSkip(info *model.TableInfo, col *table.Column, value types.Datum) bool {
if col.IsPKHandleColumn(info) {
return true
}
if col.DefaultValue == nil && value.IsNull() {
if col.GetDefaultValue() == nil && value.IsNull() {
return true
}
if col.IsGenerated() && !col.GeneratedStored {
Expand Down

0 comments on commit 8d1acc2

Please sign in to comment.