Skip to content

Commit

Permalink
Refactor build SQL condition
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 11, 2018
1 parent 86c0479 commit 7a8c2bb
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 110 deletions.
6 changes: 4 additions & 2 deletions create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ func TestCreate(t *testing.T) {
}

var newUser User
DB.First(&newUser, user.Id)
if err := DB.First(&newUser, user.Id).Error; err != nil {
t.Errorf("No error should happen, but got %v", err)
}

if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
t.Errorf("User's PasswordHash should be saved ([]byte)")
Expand All @@ -38,7 +40,7 @@ func TestCreate(t *testing.T) {
}

if newUser.UserNum != Num(111) {
t.Errorf("User's UserNum should be saved (custom type)")
t.Errorf("User's UserNum should be saved (custom type), but got %v", newUser.UserNum)
}

if newUser.Latitude != float {
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB {
// Exec execute raw sql
func (s *DB) Exec(sql string, values ...interface{}) *DB {
scope := s.NewScope(nil)
generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true)
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
scope.Raw(generatedSQL)
return scope.Exec().db
Expand Down
3 changes: 3 additions & 0 deletions migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"os"
"reflect"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -168,6 +169,8 @@ type Num int64
func (i *Num) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
n, _ := strconv.Atoi(string(s))
*i = Num(n)
case int64:
*i = Num(s)
default:
Expand Down
2 changes: 1 addition & 1 deletion query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
var address AddressByZipCode
DB.First(&address, "00501")
if address.ZipCode != "00501" {
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed")
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode)
}
}

Expand Down
152 changes: 46 additions & 106 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
Expand Down Expand Up @@ -521,126 +520,67 @@ func (scope *Scope) primaryCondition(value interface{}) string {
return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value)
}

func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) {
var (
quotedTableName = scope.QuotedTableName()
quotedPrimaryKey = scope.Quote(scope.PrimaryKey())
equalSQL = "="
inSQL = "IN"
)

// If building not conditions
if !include {
equalSQL = "<>"
inSQL = "NOT IN"
}

switch value := clause["query"].(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value))
case sql.NullInt64:
return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64)
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value)
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()))
if !include && reflect.ValueOf(value).Len() == 0 {
return
}
str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL)
clause["args"] = []interface{}{value}
case string:
if isNumberRegexp.MatchString(value) {
return scope.primaryCondition(scope.AddToVars(value))
} else if value != "" {
str = fmt.Sprintf("(%v)", value)
}
case map[string]interface{}:
var sqls []string
for key, value := range value {
if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", scope.QuotedTableName(), scope.Quote(key)))
}
return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value))
}
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
newScope := scope.New(value)
for _, field := range newScope.Fields() {
if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
}

replacements := []string{}
args := clause["args"].([]interface{})
for _, arg := range args {
var err error
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = scanner.Value()
replacements = append(replacements, scope.AddToVars(arg))
} else if b, ok := arg.([]byte); ok {
replacements = append(replacements, scope.AddToVars(b))
} else if values := reflect.ValueOf(arg); values.Len() > 0 {
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
if value != "" {
if !include {
if comparisonRegexp.MatchString(value) {
str = fmt.Sprintf("NOT (%v)", value)
} else {
str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value))
}
replacements = append(replacements, strings.Join(tempMarks, ","))
} else {
replacements = append(replacements, scope.AddToVars(Expr("NULL")))
str = fmt.Sprintf("(%v)", value)
}
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = valuer.Value()
}

replacements = append(replacements, scope.AddToVars(arg))
}
if err != nil {
scope.Err(err)
}
}

buff := bytes.NewBuffer([]byte{})
i := 0
for pos := range str {
if str[pos] == '?' {
buff.WriteString(replacements[i])
i++
} else {
buff.WriteByte(str[pos])
}
}

str = buff.String()

return
}

func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var primaryKey = scope.PrimaryKey()

switch value := clause["query"].(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value)
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(primaryKey))
clause["args"] = []interface{}{value}
} else {
return ""
}
case string:
if isNumberRegexp.MatchString(value) {
id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), id)
} else if comparisonRegexp.MatchString(value) {
str = fmt.Sprintf("NOT (%v)", value)
} else {
str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value))
}
case map[string]interface{}:
var sqls []string
for key, value := range value {
if value != nil {
sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value)))
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", scope.QuotedTableName(), scope.Quote(key)))
if !include {
sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key)))
} else {
sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key)))
}
}
}
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
var newScope = scope.New(value)
newScope := scope.New(value)
for _, field := range newScope.Fields() {
if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
Expand All @@ -667,8 +607,8 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
replacements = append(replacements, scope.AddToVars(Expr("NULL")))
}
default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = scanner.Value()
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, err = valuer.Value()
}

replacements = append(replacements, scope.AddToVars(arg))
Expand All @@ -681,7 +621,6 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string

buff := bytes.NewBuffer([]byte{})
i := 0

for pos := range str {
if str[pos] == '?' {
buff.WriteString(replacements[i])
Expand All @@ -692,6 +631,7 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string
}

str = buff.String()

return
}

Expand Down Expand Up @@ -758,19 +698,19 @@ func (scope *Scope) whereSQL() (sql string) {
}

for _, clause := range scope.Search.whereConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, true); sql != "" {
andConditions = append(andConditions, sql)
}
}

for _, clause := range scope.Search.orConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, true); sql != "" {
orConditions = append(orConditions, sql)
}
}

for _, clause := range scope.Search.notConditions {
if sql := scope.buildNotCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, false); sql != "" {
andConditions = append(andConditions, sql)
}
}
Expand Down Expand Up @@ -844,7 +784,7 @@ func (scope *Scope) havingSQL() string {

var andConditions []string
for _, clause := range scope.Search.havingConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, true); sql != "" {
andConditions = append(andConditions, sql)
}
}
Expand All @@ -860,7 +800,7 @@ func (scope *Scope) havingSQL() string {
func (scope *Scope) joinsSQL() string {
var joinConditions []string
for _, clause := range scope.Search.joinConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
if sql := scope.buildCondition(clause, true); sql != "" {
joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")"))
}
}
Expand Down

0 comments on commit 7a8c2bb

Please sign in to comment.