Skip to content

Commit

Permalink
feat: support setting field comments (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
iTanken authored Oct 11, 2024
1 parent 94b32a6 commit dd0e76a
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 0 deletions.
105 changes: 105 additions & 0 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,58 @@ func (m Migrator) GetTables() (tableList []string, err error) {
return tableList, m.DB.Raw("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_CATALOG = ?", m.CurrentDatabase()).Scan(&tableList).Error
}

func (m Migrator) CreateTable(values ...interface{}) (err error) {
if err = m.Migrator.CreateTable(values...); err != nil {
return
}
for _, value := range m.ReorderModels(values, false) {
if err = m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
if stmt.Schema == nil {
return
}
for _, fieldName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[fieldName]
if field.Comment == "" {
continue
}
if err = m.setColumnComment(stmt, field, true); err != nil {
return
}
}
return
}); err != nil {
return
}
}
return
}

func (m Migrator) setColumnComment(stmt *gorm.Statement, field *schema.Field, add bool) error {
schemaName := m.getTableSchemaName(stmt.Schema)
// add field comment
if add {
return m.DB.Exec(
"EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
field.Comment, schemaName, stmt.Table, field.DBName,
).Error
}
// update field comment
return m.DB.Exec(
"EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
field.Comment, schemaName, stmt.Table, field.DBName,
).Error
}

func (m Migrator) getTableSchemaName(schema *schema.Schema) string {
// return the schema name if it is explicitly provided in the table name
// otherwise return default schema name
schemaName := getTableSchemaName(schema)
if schemaName == "" {
schemaName = m.DefaultSchema()
}
return schemaName
}

func getTableSchemaName(schema *schema.Schema) string {
// return the schema name if it is explicitly provided in the table name
// otherwise return a sql wildcard -> use any table_schema
Expand Down Expand Up @@ -141,6 +193,26 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
).Error
}

func (m Migrator) AddColumn(value interface{}, name string) error {
if err := m.Migrator.AddColumn(value, name); err != nil {
return err
}

return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
if field.Comment == "" {
return
}
if err = m.setColumnComment(stmt, field, true); err != nil {
return
}
}
}
return
})
}

func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
Expand Down Expand Up @@ -200,6 +272,39 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
})
}

func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (description string) {
queryTx := m.DB
if m.DB.DryRun {
queryTx = m.DB.Session(&gorm.Session{})
queryTx.DryRun = false
}
var comment sql.NullString
queryTx.Raw("SELECT value FROM ?.sys.fn_listextendedproperty('MS_Description', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?)",
gorm.Expr(m.CurrentDatabase()), m.getTableSchemaName(stmt.Schema), stmt.Table, fieldDBName).Scan(&comment)
if comment.Valid {
description = comment.String
}
return
}

func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil {
return err
}

return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
description := m.GetColumnComment(stmt, field.DBName)
if field.Comment != description {
if description == "" {
err = m.setColumnComment(stmt, field, true)
} else {
err = m.setColumnComment(stmt, field, false)
}
}
return
})
}

var defaultValueTrimRegexp = regexp.MustCompile("^\\('?([^']*)'?\\)$")

// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
Expand Down
60 changes: 60 additions & 0 deletions migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,63 @@ func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, co
}
return
}

type TestTableFieldComment struct {
ID string `gorm:"column:id;primaryKey"`
Name string `gorm:"column:name;comment:姓名"`
Age uint `gorm:"column:age;comment:年龄"`
}

func (*TestTableFieldComment) TableName() string { return "test_table_field_comment" }

type TestTableFieldCommentUpdate struct {
ID string `gorm:"column:id;primaryKey"`
Name string `gorm:"column:name;comment:姓名"`
Age uint `gorm:"column:age;comment:周岁"`
Birthday *time.Time `gorm:"column:birthday;comment:生日"`
}

func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_field_comment" }

func TestMigrator_MigrateColumnComment(t *testing.T) {
db, err := gorm.Open(sqlserver.Open(sqlserverDSN))
if err != nil {
t.Error(err)
}
migrator := db.Debug().Migrator()

tableModel := new(TestTableFieldComment)
defer func() {
if err = migrator.DropTable(tableModel); err != nil {
t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err)
}
}()

if err = migrator.AutoMigrate(tableModel); err != nil {
t.Fatal(err)
}
tableModelUpdate := new(TestTableFieldCommentUpdate)
if err = migrator.AutoMigrate(tableModelUpdate); err != nil {
t.Error(err)
}

if m, ok := migrator.(sqlserver.Migrator); ok {
stmt := db.Model(tableModelUpdate).Find(nil).Statement
if stmt == nil || stmt.Schema == nil {
t.Fatal("expected Statement.Schema, got nil")
}

wantComments := []string{"", "姓名", "周岁", "生日"}
gotComments := make([]string, len(stmt.Schema.DBNames))

for i, fieldDBName := range stmt.Schema.DBNames {
comment := m.GetColumnComment(stmt, fieldDBName)
gotComments[i] = comment
}

if !reflect.DeepEqual(wantComments, gotComments) {
t.Fatalf("expected comments %#v, got %#v", wantComments, gotComments)
}
t.Logf("got comments: %#v", gotComments)
}
}

0 comments on commit dd0e76a

Please sign in to comment.