From d3876d43bbbeb98b4d67309b4d3eef49438ef74c Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Sun, 28 Jul 2019 20:03:53 +0100 Subject: [PATCH] fix(scan): Support Scan into pointer to RedisScan Add support for scanning into a pointer to a type which supports RedisScan. Fixes #418 --- redis/scan.go | 20 +++++++++++++++ redis/scan_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/redis/scan.go b/redis/scan.go index 5227aac9..135d9639 100644 --- a/redis/scan.go +++ b/redis/scan.go @@ -115,6 +115,26 @@ func convertAssignBulkString(d reflect.Value, s []byte) (err error) { } else { err = cannotConvert(d, s) } + case reflect.Ptr: + if d.CanInterface() && d.CanSet() { + if s == nil { + if d.IsNil() { + return nil + } + + d.Set(reflect.Zero(d.Type())) + return nil + } + + if d.IsNil() { + d.Set(reflect.New(d.Type().Elem())) + } + + if sc, ok := d.Interface().(Scanner); ok { + return sc.RedisScan(s) + } + } + err = convertAssignString(d, string(s)) default: err = convertAssignString(d, string(s)) } diff --git a/redis/scan_test.go b/redis/scan_test.go index 3e35242b..4c61837b 100644 --- a/redis/scan_test.go +++ b/redis/scan_test.go @@ -18,10 +18,12 @@ import ( "fmt" "math" "reflect" + "strconv" "testing" "time" "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/require" ) type durationScan struct { @@ -457,6 +459,68 @@ func TestArgs(t *testing.T) { } } +type InnerStruct struct { + Foo int64 +} + +func (f *InnerStruct) RedisScan(src interface{}) (err error) { + switch s := src.(type) { + case []byte: + f.Foo, err = strconv.ParseInt(string(s), 10, 64) + case string: + f.Foo, err = strconv.ParseInt(s, 10, 64) + default: + return fmt.Errorf("invalid type %T", src) + } + return err +} + +type OuterStruct struct { + Inner *InnerStruct +} + +func TestScanPtrRedisScan(t *testing.T) { + tests := []struct { + name string + src []interface{} + dest OuterStruct + expected OuterStruct + }{ + { + name: "value-to-nil", + src: []interface{}{[]byte("1234"), nil}, + dest: OuterStruct{&InnerStruct{}}, + expected: OuterStruct{Inner: &InnerStruct{Foo: 1234}}, + }, + { + name: "nil-to-nil", + src: []interface{}{[]byte(nil), nil}, + dest: OuterStruct{}, + expected: OuterStruct{}, + }, + { + name: "value-to-value", + src: []interface{}{[]byte("1234"), nil}, + dest: OuterStruct{Inner: &InnerStruct{Foo: 5678}}, + expected: OuterStruct{Inner: &InnerStruct{Foo: 1234}}, + }, + { + name: "nil-to-value", + src: []interface{}{[]byte(nil), nil}, + dest: OuterStruct{Inner: &InnerStruct{Foo: 1234}}, + expected: OuterStruct{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := redis.Scan(tc.src, &tc.dest.Inner) + require.NoError(t, err) + require.Equal(t, tc.expected, tc.dest) + }) + } +} + func ExampleArgs() { c, err := dial() if err != nil {