diff --git a/session_cols_test.go b/session_cols_test.go index 5f5954c72..96cb1620e 100644 --- a/session_cols_test.go +++ b/session_cols_test.go @@ -7,21 +7,38 @@ package xorm import ( "testing" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/builder" + "xorm.io/core" ) func TestSetExpr(t *testing.T) { assert.NoError(t, prepareEngine()) + type UserExprIssue struct { + Id int64 + Title string + } + + assert.NoError(t, testEngine.Sync2(new(UserExprIssue))) + + var issue = UserExprIssue{ + Title: "my issue", + } + cnt, err := testEngine.Insert(&issue) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.EqualValues(t, 1, issue.Id) + type UserExpr struct { - Id int64 - Show bool + Id int64 + IssueId int64 `xorm:"index"` + Show bool } assert.NoError(t, testEngine.Sync2(new(UserExpr))) - cnt, err := testEngine.Insert(&UserExpr{ + cnt, err = testEngine.Insert(&UserExpr{ Show: true, }) assert.NoError(t, err) @@ -34,6 +51,16 @@ func TestSetExpr(t *testing.T) { cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + tableName := testEngine.TableName(new(UserExprIssue), true) + cnt, err = testEngine.SetExpr("issue_id", + builder.Select("id"). + From(tableName). + Where(builder.Eq{"id": issue.Id})). + ID(1). + Update(new(UserExpr)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } func TestCols(t *testing.T) { diff --git a/session_update.go b/session_update.go index 402470e5a..c5c65a452 100644 --- a/session_update.go +++ b/session_update.go @@ -245,7 +245,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if err != nil { return 0, err } - colNames = append(colNames, session.engine.Quote(colName)+" = "+subQuery) + colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")") args = append(args, subArgs...) } } diff --git a/statement_args.go b/statement_args.go index c6168db1f..5353ae1ad 100644 --- a/statement_args.go +++ b/statement_args.go @@ -17,9 +17,15 @@ func writeArg(w *builder.BytesWriter, arg interface{}) error { return err } case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } if err := argv.WriteTo(w); err != nil { return err } + if _, err := w.WriteString(")"); err != nil { + return err + } default: if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil { return err diff --git a/statement_exprparam.go b/statement_exprparam.go index a72f0aeac..0cddca024 100644 --- a/statement_exprparam.go +++ b/statement_exprparam.go @@ -60,9 +60,15 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { for _, expr := range exprs.args { switch arg := expr.(type) { case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } if err := arg.WriteTo(w); err != nil { return err } + if _, err := w.WriteString(")"); err != nil { + return err + } default: if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { return err @@ -83,9 +89,15 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { switch arg := exprs.args[i].(type) { case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } if err := arg.WriteTo(w); err != nil { return err } + if _, err := w.WriteString("("); err != nil { + return err + } default: w.Append(exprs.args[i]) }