diff --git a/magic.go b/magic.go index f5bb5c7..619a8d3 100644 --- a/magic.go +++ b/magic.go @@ -1,7 +1,6 @@ package mogi import ( - "database/sql/driver" "fmt" "log" "strconv" @@ -74,36 +73,6 @@ func transmogrify(v interface{}) interface{} { return nil } -// convert args to their 64-bit versions -// for easy comparisons -func unify(v interface{}) interface{} { - switch x := v.(type) { - case int: - return int64(x) - case int32: - return int64(x) - case float32: - return float64(x) - case string: - return []byte(x) - } - return v -} - -func unifyValues(arr []driver.Value) []driver.Value { - for i, v := range arr { - arr[i] = unify(v) - } - return arr -} - -func unifyArray(arr []interface{}) []interface{} { - for i, v := range arr { - arr[i] = unify(v) - } - return arr -} - func extractColumnName(nse *sqlparser.NonStarExpr) string { if nse.As != nil { return string(nse.As) diff --git a/select_test.go b/select_test.go index 9024187..590de5f 100644 --- a/select_test.go +++ b/select_test.go @@ -59,6 +59,12 @@ func TestSelectWhere(t *testing.T) { mogi.Select().From("beer").Where("pct", 5).StubCSV(beerCSV) runBeerSelectQuery(t, db) + // where with weird type + type coolInt int + mogi.Reset() + mogi.Select().From("beer").Where("pct", coolInt(5)).StubCSV(beerCSV) + runBeerSelectQuery(t, db) + // wrong where mogi.Reset() mogi.Select().From("beer").Where("pct", 98).StubCSV(beerCSV) diff --git a/stub.go b/stub.go index 069af42..5e0c011 100644 --- a/stub.go +++ b/stub.go @@ -13,6 +13,10 @@ type Stub struct { resolve func(input) } +type subquery struct { + chain condchain +} + // Select starts a new stub for SELECT statements. // You can filter out which columns to use this stub for. // If you don't pass any columns, it will stub all SELECT queries. @@ -91,6 +95,10 @@ func (s *Stub) StubError(err error) { addStub(s) } +func (s *Stub) Subquery() subquery { + return subquery{chain: s.chain} +} + func (s *Stub) matches(in input) bool { return s.chain.matches(in) } diff --git a/unify.go b/unify.go new file mode 100644 index 0000000..3446d1c --- /dev/null +++ b/unify.go @@ -0,0 +1,103 @@ +package mogi + +import ( + "database/sql/driver" + "reflect" + "time" +) + +/* +int64 +float64 +bool +[]byte +string [*] everywhere except from Rows.Next. +time.Time +*/ + +// unify converts values to fit driver.Value, +// except []byte which is converted to string. +func unify(v interface{}) interface{} { + // happy path + switch x := v.(type) { + case nil: + return x + case bool: + return x + + // int64 + case int64: + return x + case int: + return int64(x) + case int32: + return int64(x) + case int16: + return int64(x) + case int8: + return int64(x) + case byte: + return int64(x) + + // float64 + case float32: + return float64(x) + + // string + case string: + return x + case []byte: + return string(x) + + // time.Time + case time.Time: + return x + case *time.Time: + if x == nil { + return nil + } + return *x + } + + // sad path + rv := reflect.ValueOf(v) + return reflectUnify(rv) +} + +func reflectUnify(rv reflect.Value) interface{} { + switch rv.Kind() { + case reflect.Ptr: + if rv.IsNil() { + return nil + } + return reflectUnify(rv.Elem()) + case reflect.Bool: + return rv.Bool() + case reflect.Int64, reflect.Int, reflect.Int32, reflect.Int16, reflect.Int8: + return rv.Int() + case reflect.Float64, reflect.Float32: + return rv.Float() + case reflect.String: + return rv.String() + case reflect.Slice: + if rv.Elem().Kind() == reflect.Int8 { + return string(rv.Bytes()) + } + } + + panic("couldn't unify value of type " + rv.Type().Name()) +} + +func unifyValues(values []driver.Value) []driver.Value { + for i, v := range values { + values[i] = unify(v) + } + return values +} + +func unifyInterfaces(slice []interface{}) []interface{} { + for i, v := range slice { + slice[i] = unify(v) + } + return slice +} diff --git a/where.go b/where.go index 9608a6e..04e3a95 100644 --- a/where.go +++ b/where.go @@ -14,7 +14,7 @@ type whereCond struct { func newWhereCond(col string, v []interface{}) whereCond { return whereCond{ col: strings.ToLower(col), - v: unifyArray(v), + v: unifyInterfaces(v), } } @@ -49,7 +49,7 @@ type whereOpCond struct { func newWhereOpCond(col string, v []interface{}, op string) whereOpCond { return whereOpCond{ col: strings.ToLower(col), - v: unifyArray(v), + v: unifyInterfaces(v), op: strings.ToLower(op), } }