diff --git a/array.go b/array.go index 39c8f7e2..3769cec8 100644 --- a/array.go +++ b/array.go @@ -20,6 +20,7 @@ var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // // For example: // db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, []int{235, 401}) // go1.9+ // // var x []sql.NullInt64 // db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) diff --git a/conn_go19.go b/conn_go19.go new file mode 100644 index 00000000..15f31be4 --- /dev/null +++ b/conn_go19.go @@ -0,0 +1,30 @@ +// +build go1.9 + +package pq + +import ( + "database/sql/driver" + "reflect" +) + +var _ driver.NamedValueChecker = (*conn)(nil) + +func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { + if _, ok := nv.Value.(driver.Valuer); ok { + // Ignore Valuer, for backward compatiblity with pq.Array() + return driver.ErrSkip + } + + // Ignoring []byte / []uint8 + if _, ok := nv.Value.([]uint8); ok { + return driver.ErrSkip + } + + if k := reflect.ValueOf(nv.Value).Kind(); k == reflect.Array || k == reflect.Slice { + var err error + nv.Value, err = Array(nv.Value).Value() + return err + } + + return driver.ErrSkip +} diff --git a/conn_go19_test.go b/conn_go19_test.go new file mode 100644 index 00000000..abbee91b --- /dev/null +++ b/conn_go19_test.go @@ -0,0 +1,71 @@ +// +build go1.9 + +package pq + +import ( + "reflect" + "testing" +) + +func TestArrayArg(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + for _, tc := range []struct { + name string + want interface{} + }{ + { + name: "array-value", + want: [...]int64{245, 231}, + }, + { + name: "slice-value", + want: []int64{245, 231}, + }, + { + name: "array-pointer", + want: &[...]int64{245, 231}, + }, + { + name: "slice-pointer", + want: &[]int64{245, 231}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + r, err := db.Query("SELECT $1::int[]", Array(tc.want)) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(r.Err()) + } + t.Fatal("expected row") + } + + defer func() { + if r.Next() { + t.Fatal("unexpected row") + } + }() + + rt := reflect.TypeOf(tc.want) + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + } + + got := reflect.New(rt).Interface() + if err := r.Scan(Array(got)); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(tc.want, got) { + t.Errorf("got %v, want %v", got, tc.want) + } + }) + } + +}