Skip to content
This repository has been archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.

Commit

Permalink
Add insert select where support (#1401)
Browse files Browse the repository at this point in the history
  • Loading branch information
lunny committed Aug 22, 2019
1 parent b78ac8c commit 17592d9
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 21 deletions.
114 changes: 93 additions & 21 deletions session_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strconv"
"strings"

"xorm.io/builder"
"xorm.io/core"
)

Expand Down Expand Up @@ -345,7 +346,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
for _, v := range exprColumns {
// remove the expr columns
for i, colName := range colNames {
if colName == v.colName {
if colName == strings.Trim(v.colName, "`") {
colNames = append(colNames[:i], colNames[i+1:]...)
args = append(args[:i], args[i+1:]...)
}
Expand All @@ -371,12 +372,30 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
}

if len(colPlaces) > 0 {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
session.engine.Quote(tableName),
quoteColumns(colNames, session.engine.Quote, ","),
output,
colPlaces)
if session.statement.cond.IsValid() {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return 0, err
}

sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v",
session.engine.Quote(tableName),
quoteColumns(colNames, session.engine.Quote, ","),
output,
colPlaces,
session.engine.Quote(tableName),
condSQL,
)
args = append(args, condArgs...)
} else {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
session.engine.Quote(tableName),
quoteColumns(colNames, session.engine.Quote, ","),
output,
colPlaces)
}
} else {
if session.engine.dialect.DBType() == core.MYSQL {
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
Expand Down Expand Up @@ -663,26 +682,52 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
return 0, ErrParamsType
}

tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}

var columns = make([]string, 0, len(m))
for k := range m {
columns = append(columns, k)
}
sort.Strings(columns)

qm := strings.Repeat("?,", len(columns))
qm = "(" + qm[:len(qm)-1] + ")"

tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}

var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}

// insert expr columns, override if exists
exprColumns := session.statement.getExpr()
for _, col := range exprColumns {
columns = append(columns, strings.Trim(col.colName, "`"))
qm = qm + col.expr + ","
}

qm = qm[:len(qm)-1]

var sql string

if session.statement.cond.IsValid() {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return 0, err
}
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
session.engine.Quote(tableName),
strings.Join(columns, "`,`"),
qm,
session.engine.Quote(tableName),
condSQL,
)
args = append(args, condArgs...)
} else {
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
}

if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
Expand All @@ -703,24 +748,51 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
return 0, ErrParamsType
}

tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}

var columns = make([]string, 0, len(m))
for k := range m {
columns = append(columns, k)
}
sort.Strings(columns)

var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}

qm := strings.Repeat("?,", len(columns))
qm = "(" + qm[:len(qm)-1] + ")"

tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
// insert expr columns, override if exists
exprColumns := session.statement.getExpr()
for _, col := range exprColumns {
columns = append(columns, strings.Trim(col.colName, "`"))
qm = qm + col.expr + ","
}

var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
qm = qm[:len(qm)-1]

var sql string

if session.statement.cond.IsValid() {
qm = "(" + qm[:len(qm)-1] + ")"
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return 0, err
}
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
session.engine.Quote(tableName),
strings.Join(columns, "`,`"),
qm,
session.engine.Quote(tableName),
condSQL,
)
args = append(args, condArgs...)
} else {
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
}

if err := session.cacheInsert(tableName); err != nil {
Expand Down
59 changes: 59 additions & 0 deletions session_insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -834,3 +834,62 @@ func TestInsertMap(t *testing.T) {
assert.EqualValues(t, 10, ims[3].Height)
assert.EqualValues(t, "lunny", ims[3].Name)
}

/*INSERT INTO `issue` (`repo_id`, `poster_id`, ... ,`name`, `content`, ... ,`index`)
SELECT $1, $2, ..., $14, $15, ..., MAX(`index`) + 1 FROM `issue` WHERE `repo_id` = $1;
*/
func TestInsertWhere(t *testing.T) {
type InsertWhere struct {
Id int64
Index int `xorm:"unique(s) notnull"`
RepoId int64 `xorm:"unique(s)"`
Width uint32
Height uint32
Name string
}

assert.NoError(t, prepareEngine())
assertSync(t, new(InsertWhere))

var i = InsertWhere{
RepoId: 1,
Width: 10,
Height: 20,
Name: "trest",
}

inserted, err := testEngine.SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
Where("repo_id=?", 1).
Insert(&i)
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)
assert.EqualValues(t, 1, i.Id)

var j InsertWhere
has, err := testEngine.ID(i.Id).Get(&j)
assert.NoError(t, err)
assert.True(t, has)
i.Index = 1
assert.EqualValues(t, i, j)

inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1).
SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
Insert(map[string]interface{}{
"repo_id": 1,
"width": 20,
"height": 40,
"name": "trest2",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)

var j2 InsertWhere
has, err = testEngine.ID(2).Get(&j2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, j2.RepoId)
assert.EqualValues(t, 20, j2.Width)
assert.EqualValues(t, 40, j2.Height)
assert.EqualValues(t, "trest2", j2.Name)
assert.EqualValues(t, 2, j2.Index)
}

0 comments on commit 17592d9

Please sign in to comment.