diff --git a/session_insert.go b/session_insert.go index 713565661..24b328314 100644 --- a/session_insert.go +++ b/session_insert.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" + "xorm.io/builder" "xorm.io/core" ) @@ -345,7 +346,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { for _, v := range exprColumns { // remove the expr columns for i, colName := range colNames { - if colName == v.colName { + if colName == strings.Trim(v.colName, "`") { colNames = append(colNames[:i], colNames[i+1:]...) args = append(args[:i], args[i+1:]...) } @@ -371,12 +372,30 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) } + if len(colPlaces) > 0 { - sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), - output, - colPlaces) + if session.statement.cond.IsValid() { + condSQL, condArgs, err := builder.ToSQL(session.statement.cond) + if err != nil { + return 0, err + } + + sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v", + session.engine.Quote(tableName), + quoteColumns(colNames, session.engine.Quote, ","), + output, + colPlaces, + session.engine.Quote(tableName), + condSQL, + ) + args = append(args, condArgs...) + } else { + sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)", + session.engine.Quote(tableName), + quoteColumns(colNames, session.engine.Quote, ","), + output, + colPlaces) + } } else { if session.engine.dialect.DBType() == core.MYSQL { sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName)) @@ -663,6 +682,11 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err return 0, ErrParamsType } + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return 0, ErrTableNotFound + } + var columns = make([]string, 0, len(m)) for k := range m { columns = append(columns, k) @@ -670,19 +694,40 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err sort.Strings(columns) qm := strings.Repeat("?,", len(columns)) - qm = "(" + qm[:len(qm)-1] + ")" - - tableName := session.statement.TableName() - if len(tableName) <= 0 { - return 0, ErrTableNotFound - } - var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) var args = make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } + // insert expr columns, override if exists + exprColumns := session.statement.getExpr() + for _, col := range exprColumns { + columns = append(columns, strings.Trim(col.colName, "`")) + qm = qm + col.expr + "," + } + + qm = qm[:len(qm)-1] + + var sql string + + if session.statement.cond.IsValid() { + condSQL, condArgs, err := builder.ToSQL(session.statement.cond) + if err != nil { + return 0, err + } + sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s", + session.engine.Quote(tableName), + strings.Join(columns, "`,`"), + qm, + session.engine.Quote(tableName), + condSQL, + ) + args = append(args, condArgs...) + } else { + sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) + } + if err := session.cacheInsert(tableName); err != nil { return 0, err } @@ -703,24 +748,51 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { return 0, ErrParamsType } + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return 0, ErrTableNotFound + } + var columns = make([]string, 0, len(m)) for k := range m { columns = append(columns, k) } sort.Strings(columns) + var args = make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + qm := strings.Repeat("?,", len(columns)) - qm = "(" + qm[:len(qm)-1] + ")" - tableName := session.statement.TableName() - if len(tableName) <= 0 { - return 0, ErrTableNotFound + // insert expr columns, override if exists + exprColumns := session.statement.getExpr() + for _, col := range exprColumns { + columns = append(columns, strings.Trim(col.colName, "`")) + qm = qm + col.expr + "," } - var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) - var args = make([]interface{}, 0, len(m)) - for _, colName := range columns { - args = append(args, m[colName]) + qm = qm[:len(qm)-1] + + var sql string + + if session.statement.cond.IsValid() { + qm = "(" + qm[:len(qm)-1] + ")" + condSQL, condArgs, err := builder.ToSQL(session.statement.cond) + if err != nil { + return 0, err + } + sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s", + session.engine.Quote(tableName), + strings.Join(columns, "`,`"), + qm, + session.engine.Quote(tableName), + condSQL, + ) + args = append(args, condArgs...) + } else { + sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) } if err := session.cacheInsert(tableName); err != nil { diff --git a/session_insert_test.go b/session_insert_test.go index 8e7ffa990..daf08e7f0 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -834,3 +834,62 @@ func TestInsertMap(t *testing.T) { assert.EqualValues(t, 10, ims[3].Height) assert.EqualValues(t, "lunny", ims[3].Name) } + +/*INSERT INTO `issue` (`repo_id`, `poster_id`, ... ,`name`, `content`, ... ,`index`) +SELECT $1, $2, ..., $14, $15, ..., MAX(`index`) + 1 FROM `issue` WHERE `repo_id` = $1; +*/ +func TestInsertWhere(t *testing.T) { + type InsertWhere struct { + Id int64 + Index int `xorm:"unique(s) notnull"` + RepoId int64 `xorm:"unique(s)"` + Width uint32 + Height uint32 + Name string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(InsertWhere)) + + var i = InsertWhere{ + RepoId: 1, + Width: 10, + Height: 20, + Name: "trest", + } + + inserted, err := testEngine.SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + Where("repo_id=?", 1). + Insert(&i) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + assert.EqualValues(t, 1, i.Id) + + var j InsertWhere + has, err := testEngine.ID(i.Id).Get(&j) + assert.NoError(t, err) + assert.True(t, has) + i.Index = 1 + assert.EqualValues(t, i, j) + + inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + Insert(map[string]interface{}{ + "repo_id": 1, + "width": 20, + "height": 40, + "name": "trest2", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var j2 InsertWhere + has, err = testEngine.ID(2).Get(&j2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, j2.RepoId) + assert.EqualValues(t, 20, j2.Width) + assert.EqualValues(t, 40, j2.Height) + assert.EqualValues(t, "trest2", j2.Name) + assert.EqualValues(t, 2, j2.Index) +}