Skip to content

Commit

Permalink
beef up interface{} unifier
Browse files Browse the repository at this point in the history
  • Loading branch information
guregu committed Jan 15, 2015
1 parent a993e1e commit 5531c07
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 33 deletions.
31 changes: 0 additions & 31 deletions magic.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package mogi

import (
"database/sql/driver"
"fmt"
"log"
"strconv"
Expand Down Expand Up @@ -74,36 +73,6 @@ func transmogrify(v interface{}) interface{} {
return nil
}

// convert args to their 64-bit versions
// for easy comparisons
func unify(v interface{}) interface{} {
switch x := v.(type) {
case int:
return int64(x)
case int32:
return int64(x)
case float32:
return float64(x)
case string:
return []byte(x)
}
return v
}

func unifyValues(arr []driver.Value) []driver.Value {
for i, v := range arr {
arr[i] = unify(v)
}
return arr
}

func unifyArray(arr []interface{}) []interface{} {
for i, v := range arr {
arr[i] = unify(v)
}
return arr
}

func extractColumnName(nse *sqlparser.NonStarExpr) string {
if nse.As != nil {
return string(nse.As)
Expand Down
6 changes: 6 additions & 0 deletions select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ func TestSelectWhere(t *testing.T) {
mogi.Select().From("beer").Where("pct", 5).StubCSV(beerCSV)
runBeerSelectQuery(t, db)

// where with weird type
type coolInt int
mogi.Reset()
mogi.Select().From("beer").Where("pct", coolInt(5)).StubCSV(beerCSV)
runBeerSelectQuery(t, db)

// wrong where
mogi.Reset()
mogi.Select().From("beer").Where("pct", 98).StubCSV(beerCSV)
Expand Down
8 changes: 8 additions & 0 deletions stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ type Stub struct {
resolve func(input)
}

type subquery struct {
chain condchain
}

// Select starts a new stub for SELECT statements.
// You can filter out which columns to use this stub for.
// If you don't pass any columns, it will stub all SELECT queries.
Expand Down Expand Up @@ -91,6 +95,10 @@ func (s *Stub) StubError(err error) {
addStub(s)
}

func (s *Stub) Subquery() subquery {
return subquery{chain: s.chain}
}

func (s *Stub) matches(in input) bool {
return s.chain.matches(in)
}
Expand Down
103 changes: 103 additions & 0 deletions unify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package mogi

import (
"database/sql/driver"
"reflect"
"time"
)

/*
int64
float64
bool
[]byte
string [*] everywhere except from Rows.Next.
time.Time
*/

// unify converts values to fit driver.Value,
// except []byte which is converted to string.
func unify(v interface{}) interface{} {
// happy path
switch x := v.(type) {
case nil:
return x
case bool:
return x

// int64
case int64:
return x
case int:
return int64(x)
case int32:
return int64(x)
case int16:
return int64(x)
case int8:
return int64(x)
case byte:
return int64(x)

// float64
case float32:
return float64(x)

// string
case string:
return x
case []byte:
return string(x)

// time.Time
case time.Time:
return x
case *time.Time:
if x == nil {
return nil
}
return *x
}

// sad path
rv := reflect.ValueOf(v)
return reflectUnify(rv)
}

func reflectUnify(rv reflect.Value) interface{} {
switch rv.Kind() {
case reflect.Ptr:
if rv.IsNil() {
return nil
}
return reflectUnify(rv.Elem())
case reflect.Bool:
return rv.Bool()
case reflect.Int64, reflect.Int, reflect.Int32, reflect.Int16, reflect.Int8:
return rv.Int()
case reflect.Float64, reflect.Float32:
return rv.Float()
case reflect.String:
return rv.String()
case reflect.Slice:
if rv.Elem().Kind() == reflect.Int8 {
return string(rv.Bytes())
}
}

panic("couldn't unify value of type " + rv.Type().Name())
}

func unifyValues(values []driver.Value) []driver.Value {
for i, v := range values {
values[i] = unify(v)
}
return values
}

func unifyInterfaces(slice []interface{}) []interface{} {
for i, v := range slice {
slice[i] = unify(v)
}
return slice
}
4 changes: 2 additions & 2 deletions where.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type whereCond struct {
func newWhereCond(col string, v []interface{}) whereCond {
return whereCond{
col: strings.ToLower(col),
v: unifyArray(v),
v: unifyInterfaces(v),
}
}

Expand Down Expand Up @@ -49,7 +49,7 @@ type whereOpCond struct {
func newWhereOpCond(col string, v []interface{}, op string) whereOpCond {
return whereOpCond{
col: strings.ToLower(col),
v: unifyArray(v),
v: unifyInterfaces(v),
op: strings.ToLower(op),
}
}
Expand Down

0 comments on commit 5531c07

Please sign in to comment.