diff --git a/scanner.go b/scanner.go index 1c7ebbd..e8cead9 100644 --- a/scanner.go +++ b/scanner.go @@ -165,6 +165,8 @@ func initFieldTag(sliceItem reflect.Value, fieldTagMap *map[string]reflect.Value } } +var sqlScannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + func structPointers(sliceItem reflect.Value, cols []string, strict bool) []interface{} { pointers := make([]interface{}, 0, len(cols)) fieldTag := make(map[string]reflect.Value, len(cols)) @@ -179,6 +181,13 @@ func structPointers(sliceItem reflect.Value, cols []string, strict bool) []inter fieldVal = reflect.ValueOf(nil) } else { fieldVal = sliceItem.FieldByName(ScannerMapper(colName)) + if fieldVal == (reflect.Value{}) { + if sliceItem.Addr().Type().Implements(sqlScannerType) { + // probably this is a custom struct that implements sql.Scanner. + // do our best and don't set "nothing" as a pointer + fieldVal = sliceItem + } + } } } if !fieldVal.IsValid() || !fieldVal.CanSet() { diff --git a/sql_scanner_test.go b/sql_scanner_test.go new file mode 100644 index 0000000..91b1fab --- /dev/null +++ b/sql_scanner_test.go @@ -0,0 +1,72 @@ +package scan_test + +import ( + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/blockloop/scan/v2" +) + +func TestCustomScanner(t *testing.T) { + db := mustDB(t.Name(), ` + CREATE TABLE test + ( + id int PRIMARY KEY, + data int NOT NULL + )`, + `INSERT INTO test (id, data) VALUES (1, 123), (2, 234)`, + ) + t.Cleanup(func() { db.Close() }) + + const ( + selectOneQuery = `SELECT data FROM test WHERE id = 1` + selectAllQuery = `SELECT data FROM test` + ) + + t.Run("scan.Row must work", func(t *testing.T) { + var data customScanner + rows, err := db.Query(selectOneQuery) + require.NoError(t, err) + + err = scan.Row(&data, rows) + require.NoError(t, err) + require.Equal(t, customScanner{v: 123}, data) + }) + + t.Run("scan.Rows must work", func(t *testing.T) { + var data []customScanner + rows, err := db.Query(selectAllQuery) + require.NoError(t, err) + + err = scan.Rows(&data, rows) + require.NoError(t, err) + require.ElementsMatch(t, []customScanner{{v: 123}, {v: 234}}, data) + }) + +} + +type customScanner struct { + v int64 +} + +func (c *customScanner) Scan(src interface{}) error { + switch v := src.(type) { + case int64: + *c = customScanner{v: v} + case []byte: // for ramsql + n, err := strconv.ParseInt(string(v), 10, 64) + if err != nil { + return fmt.Errorf("parse int: %w", err) + } + *c = customScanner{v: n} + case nil: + return nil + default: + return fmt.Errorf("unsupported type %T", src) + } + + return nil +}