From c93a42b40a772c2e0faf50d7f1732a0738d24368 Mon Sep 17 00:00:00 2001 From: Zdenek Tison Date: Mon, 20 Mar 2023 12:07:38 +0100 Subject: [PATCH] Custom type marshal/unmarshal functions --- feature/dynamodb/attributevalue/decode.go | 79 +++++++++++++-- .../dynamodb/attributevalue/decode_test.go | 39 ++++++++ feature/dynamodb/attributevalue/encode.go | 84 ++++++++++++---- .../dynamodb/attributevalue/encode_test.go | 40 ++++++++ .../dynamodb/attributevalue/marshaler_test.go | 97 +++++++++++++++++++ .../dynamodb/attributevalue/shared_test.go | 65 +++++++++++++ .../dynamodbstreams/attributevalue/decode.go | 77 +++++++++++++-- .../attributevalue/decode_test.go | 40 ++++++++ .../dynamodbstreams/attributevalue/encode.go | 83 ++++++++++++---- .../attributevalue/encode_test.go | 40 ++++++++ .../attributevalue/marshaler_test.go | 97 +++++++++++++++++++ .../attributevalue/shared_test.go | 65 +++++++++++++ 12 files changed, 749 insertions(+), 57 deletions(-) diff --git a/feature/dynamodb/attributevalue/decode.go b/feature/dynamodb/attributevalue/decode.go index a0a8b53f76f..cbc7e7357c1 100644 --- a/feature/dynamodb/attributevalue/decode.go +++ b/feature/dynamodb/attributevalue/decode.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "strconv" + "sync" "time" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" @@ -235,7 +236,9 @@ type DecoderOptions struct { // A Decoder provides unmarshaling AttributeValues to Go value types. type Decoder struct { - options DecoderOptions + options DecoderOptions + unmarshalersLock sync.RWMutex + unmarshalers map[reflect.Type]func(types.AttributeValue) (interface{}, error) } // NewDecoder creates a new Decoder with default configuration. Use @@ -279,6 +282,34 @@ func (d *Decoder) Decode(av types.AttributeValue, out interface{}, opts ...func( return d.decode(av, v, tag{}) } +// RegisterUnmarshaler registers a custom unmarshaler to use for a provided type. +// +// Precedence is given to registered unmarshalers that operate on concrete types, +// then the UnmarshalDynamoDBAttributeValue method, and lastly the default behavior of Decode. +func (d *Decoder) RegisterUnmarshaler(t reflect.Type, fn func(types.AttributeValue) (interface{}, error)) error { + if t == nil { + return fmt.Errorf("type can't be nil") + } + if fn == nil { + return fmt.Errorf("unmarshaler function can't be nil") + } + switch t.Kind() { + case reflect.Ptr, reflect.Chan, reflect.Invalid, reflect.Func, reflect.UnsafePointer: + return fmt.Errorf("not supported kind %q", t.Kind()) + default: + d.unmarshalersLock.Lock() + defer d.unmarshalersLock.Unlock() + if d.unmarshalers == nil { + d.unmarshalers = map[reflect.Type]func(types.AttributeValue) (interface{}, error){t: fn} + } else if _, ok := d.unmarshalers[t]; ok { + return fmt.Errorf("unmarshaler has already been registered for type %q", t) + } else { + d.unmarshalers[t] = fn + } + return nil + } +} + var stringInterfaceMapType = reflect.TypeOf(map[string]interface{}(nil)) var byteSliceType = reflect.TypeOf([]byte(nil)) var byteSliceSliceType = reflect.TypeOf([][]byte(nil)) @@ -288,14 +319,14 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag) var u Unmarshaler _, isNull := av.(*types.AttributeValueMemberNULL) if av == nil || isNull { - u, v = indirect(v, indirectOptions{decodeNull: true}) + u, v = d.indirect(v, indirectOptions{decodeNull: true}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(av) } return d.decodeNull(v) } - u, v = indirect(v, indirectOptions{}) + u, v = d.indirect(v, indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(av) } @@ -420,7 +451,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), indirectOptions{}) + u, elem := d.indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberBS{Value: bs}) } @@ -555,7 +586,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), indirectOptions{}) + u, elem := d.indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberNS{Value: ns}) } @@ -634,7 +665,7 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val for k, av := range avMap { key := reflect.New(keyType).Elem() // handle pointer keys - _, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true}) + _, indirectKey := d.indirect(key, indirectOptions{skipUnmarshaler: true}) if err := decodeMapKey(k, indirectKey, tag{}); err != nil { return &UnmarshalTypeError{ Value: fmt.Sprintf("map key %q", k), @@ -777,7 +808,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), indirectOptions{}) + u, elem := d.indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberSS{Value: ss}) } @@ -820,12 +851,26 @@ type indirectOptions struct { skipUnmarshaler bool } +type typeUnmarshaler struct { + fn func(types.AttributeValue) (interface{}, error) + in reflect.Value +} + +func (u typeUnmarshaler) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { + if out, err := u.fn(av); err != nil { + return err + } else if out != nil { + u.in.Set(reflect.ValueOf(out)) + } + return nil +} + // indirect will walk a value's interface or pointer value types. Returning // the final value or the value a unmarshaler is defined on. // // Based on the enoding/json type reflect value type indirection in Go Stdlib // https://golang.org/src/encoding/json/decode.go indirect func. -func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) { +func (d *Decoder) indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) { // Issue #24153 indicates that it is generally not a guaranteed property // that you may round-trip a reflect.Value by calling Value.Addr().Elem() // and expect the value to still be settable for values derived from @@ -879,6 +924,11 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } + + if u := d.getTypeUnmarshaler(v); u != nil { + return u, reflect.Value{} + } + if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { return u, reflect.Value{} @@ -896,6 +946,19 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value return nil, v } +func (d *Decoder) getTypeUnmarshaler(v reflect.Value) Unmarshaler { + v = valueElem(v) + if v.Kind() == reflect.Invalid { + return nil + } + d.unmarshalersLock.RLock() + defer d.unmarshalersLock.RUnlock() + if fn, ok := d.unmarshalers[v.Type()]; ok { + return typeUnmarshaler{fn, v} + } + return nil +} + // A Number represents a Attributevalue number literal. type Number string diff --git a/feature/dynamodb/attributevalue/decode_test.go b/feature/dynamodb/attributevalue/decode_test.go index 229eec6609c..ad1f1fafec2 100644 --- a/feature/dynamodb/attributevalue/decode_test.go +++ b/feature/dynamodb/attributevalue/decode_test.go @@ -1172,5 +1172,44 @@ func TestUnmarshalMap_keyPtrTypes(t *testing.T) { t.Errorf("expect %v key not found", *k) } } +} +func TestDecoderTypeUnmarshalers(t *testing.T) { + for name, c := range sharedTypeMarshalersTestCases { + t.Run(name, func(t *testing.T) { + called := false + dec := NewDecoder() + err := dec.RegisterUnmarshaler(reflect.TypeOf(c.expected), func(value types.AttributeValue) (interface{}, error) { + called = true + return c.expected, nil + }) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + err = dec.Decode(c.in, c.actual) + if !called { + t.Fatalf("expected unmarshaler to be called") + } + assertConvertTest(t, c.actual, c.expected, err, nil) + }) + } +} + +func TestDecoderNotSupportedUnmarshalerType(t *testing.T) { + cases := map[string]reflect.Type{ + "pointer": reflect.TypeOf(new(string)), + "channel": reflect.TypeOf(make(chan int)), + "func": reflect.TypeOf(func() {}), + } + + for name, caseType := range cases { + t.Run(name, func(t *testing.T) { + err := NewDecoder().RegisterUnmarshaler(caseType, func(value types.AttributeValue) (interface{}, error) { + return nil, nil + }) + if err == nil { + t.Errorf("expect error when registering unmarshaler for unsupported type %q", caseType) + } + }) + } } diff --git a/feature/dynamodb/attributevalue/encode.go b/feature/dynamodb/attributevalue/encode.go index f8b2246c894..c97274383ad 100644 --- a/feature/dynamodb/attributevalue/encode.go +++ b/feature/dynamodb/attributevalue/encode.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "strconv" + "sync" "time" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" @@ -384,7 +385,9 @@ type EncoderOptions struct { // An Encoder provides marshaling Go value types to AttributeValues. type Encoder struct { - options EncoderOptions + options EncoderOptions + marshalersLock sync.RWMutex + marshalers map[reflect.Type]func(interface{}) (types.AttributeValue, error) } // NewEncoder creates a new Encoder with default configuration. Use @@ -414,6 +417,34 @@ func (e *Encoder) Encode(in interface{}) (types.AttributeValue, error) { return e.encode(reflect.ValueOf(in), tag{}) } +// RegisterMarshaler registers a custom marshaler to use for a provided type. +// +// Precedence is given to registered marshalers that operate on concrete types, +// then the MarshalDynamoDBAttributeValue method, and lastly the default behavior of Encode. +func (e *Encoder) RegisterMarshaler(t reflect.Type, fn func(interface{}) (types.AttributeValue, error)) error { + if t == nil { + return fmt.Errorf("type can't be nil") + } + if fn == nil { + return fmt.Errorf("marshaller function can't be nil") + } + switch t.Kind() { + case reflect.Ptr, reflect.Chan, reflect.Invalid, reflect.Func, reflect.UnsafePointer: + return fmt.Errorf("not supported kind %q", t.Kind()) + default: + e.marshalersLock.Lock() + defer e.marshalersLock.Unlock() + if e.marshalers == nil { + e.marshalers = map[reflect.Type]func(interface{}) (types.AttributeValue, error){t: fn} + } else if _, ok := e.marshalers[t]; ok { + return fmt.Errorf("marshaler has already been registered for type %q", t) + } else { + e.marshalers[t] = fn + } + return nil + } +} + func (e *Encoder) encode(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { // Ignore fields explicitly marked to be skipped. if fieldTag.Ignore { @@ -433,12 +464,11 @@ func (e *Encoder) encode(v reflect.Value, fieldTag tag) (types.AttributeValue, e return encodeNull(), nil } } - // Handle both pointers and interface conversion into types v = valueElem(v) if v.Kind() != reflect.Invalid { - if av, err := tryMarshaler(v); err != nil { + if av, err := e.tryMarshaler(v); err != nil { return nil, err } else if av != nil { return av, nil @@ -714,7 +744,7 @@ func (e *Encoder) encodeScalar(v reflect.Value, fieldTag tag) (types.AttributeVa } func (e *Encoder) encodeNumber(v reflect.Value) (types.AttributeValue, error) { - if av, err := tryMarshaler(v); err != nil { + if av, err := e.tryMarshaler(v); err != nil { return nil, err } else if av != nil { return av, nil @@ -742,7 +772,7 @@ func (e *Encoder) encodeNumber(v reflect.Value) (types.AttributeValue, error) { } func (e *Encoder) encodeString(v reflect.Value) (types.AttributeValue, error) { - if av, err := tryMarshaler(v); err != nil { + if av, err := e.tryMarshaler(v); err != nil { return nil, err } else if av != nil { return av, nil @@ -758,6 +788,34 @@ func (e *Encoder) encodeString(v reflect.Value) (types.AttributeValue, error) { } } +func (e *Encoder) tryMarshaler(v reflect.Value) (types.AttributeValue, error) { + if av, err := e.tryTypeMarshaler(v); err != nil { + return nil, err + } else if av != nil { + return av, nil + } + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + } + if v.Type().NumMethod() == 0 { + return nil, nil + } + if m, ok := v.Interface().(Marshaler); ok { + return m.MarshalDynamoDBAttributeValue() + } + + return nil, nil +} + +func (e *Encoder) tryTypeMarshaler(v reflect.Value) (types.AttributeValue, error) { + e.marshalersLock.RLock() + defer e.marshalersLock.RUnlock() + if m, ok := e.marshalers[v.Type()]; ok { + return m(v.Interface()) + } + return nil, nil +} + func encodeInt(i int64) string { return strconv.FormatInt(i, 10) } @@ -832,22 +890,6 @@ func isNullableZeroValue(v reflect.Value) bool { return false } -func tryMarshaler(v reflect.Value) (types.AttributeValue, error) { - if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { - v = v.Addr() - } - - if v.Type().NumMethod() == 0 { - return nil, nil - } - - if m, ok := v.Interface().(Marshaler); ok { - return m.MarshalDynamoDBAttributeValue() - } - - return nil, nil -} - // An InvalidMarshalError is an error type representing an error // occurring when marshaling a Go value type to an AttributeValue. type InvalidMarshalError struct { diff --git a/feature/dynamodb/attributevalue/encode_test.go b/feature/dynamodb/attributevalue/encode_test.go index a997331572d..4ab477ecb7c 100644 --- a/feature/dynamodb/attributevalue/encode_test.go +++ b/feature/dynamodb/attributevalue/encode_test.go @@ -619,3 +619,43 @@ func TestMarshalMap_keyTypes(t *testing.T) { }) } } + +func TestEncoderTypesMarshalers(t *testing.T) { + for name, c := range sharedTypeMarshalersTestCases { + t.Run(name, func(t *testing.T) { + called := false + enc := NewEncoder() + err := enc.RegisterMarshaler(reflect.TypeOf(c.expected), func(i interface{}) (types.AttributeValue, error) { + called = true + return c.in, nil + }) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + av, err := enc.Encode(c.expected) + if !called { + t.Fatalf("expected marshaler to be called") + } + assertConvertTest(t, av, c.in, err, nil) + }) + } +} + +func TestEncoderNotSupportedMarshalerType(t *testing.T) { + cases := map[string]reflect.Type{ + "pointer": reflect.TypeOf(new(string)), + "channel": reflect.TypeOf(make(chan int)), + "func": reflect.TypeOf(func() {}), + } + + for name, caseType := range cases { + t.Run(name, func(t *testing.T) { + err := NewEncoder().RegisterMarshaler(caseType, func(i interface{}) (types.AttributeValue, error) { + return nil, nil + }) + if err == nil { + t.Errorf("expect error when registering marshaler for unsupported type %q", caseType) + } + }) + } +} diff --git a/feature/dynamodb/attributevalue/marshaler_test.go b/feature/dynamodb/attributevalue/marshaler_test.go index fa810728ea9..5f58f393604 100644 --- a/feature/dynamodb/attributevalue/marshaler_test.go +++ b/feature/dynamodb/attributevalue/marshaler_test.go @@ -709,3 +709,100 @@ func Test_Encode_YAML_TagKey(t *testing.T) { compareObjects(t, expected, actual) } + +func BenchmarkEncoderTypeMarshaler(b *testing.B) { + fieldCache = &fieldCacher{} + + simple := simpleMarshalStruct{ + String: "abc", + Int: 123, + Uint: 123, + Float32: 123.321, + Float64: 123.321, + Bool: true, + } + + type MyCompositeStruct struct { + A simpleMarshalStruct `dynamodbav:"a"` + } + + var marshalerFn = func(value string) func(i interface{}) (types.AttributeValue, error) { + return func(interface{}) (types.AttributeValue, error) { + return &types.AttributeValueMemberS{Value: value}, nil + } + } + + fns := map[reflect.Type]func(i interface{}) (types.AttributeValue, error){ + reflect.TypeOf("abc"): marshalerFn("abc"), + reflect.TypeOf(123): marshalerFn("123"), + reflect.TypeOf(uint(123)): marshalerFn("uint(123)"), + reflect.TypeOf(float32(123.321)): marshalerFn("float32(123.321)"), + reflect.TypeOf(123.321): marshalerFn("123.321"), + reflect.TypeOf(true): marshalerFn("true"), + } + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + enc := NewEncoder() + for t, fn := range fns { + if err := enc.RegisterMarshaler(t, fn); err != nil { + b.Error("unexpected error:", err) + } + } + if _, err := enc.Encode(MyCompositeStruct{ + A: simple, + }); err != nil { + b.Error("unexpected error:", err) + } + } + }) +} + +func BenchmarkDecoderTypeUnmarshaler(b *testing.B) { + myStructAVMap, _ := Marshal(simpleMarshalStruct{ + String: "abc", + Int: 123, + Uint: 123, + Float32: 123.321, + Float64: 123.321, + Bool: true, + }) + + type MyCompositeStructOne struct { + A simpleMarshalStruct `dynamodbav:"a"` + } + + var unmarshalerFn = func(value interface{}) func(types.AttributeValue) (interface{}, error) { + return func(types.AttributeValue) (interface{}, error) { + return value, nil + } + } + + fns := map[reflect.Type]func(types.AttributeValue) (interface{}, error){ + reflect.TypeOf("abc"): unmarshalerFn("abc"), + reflect.TypeOf(123): unmarshalerFn(123), + reflect.TypeOf(uint(123)): unmarshalerFn(uint(123)), + reflect.TypeOf(float32(123.321)): unmarshalerFn(float32(123.321)), + reflect.TypeOf(123.321): unmarshalerFn(123.321), + reflect.TypeOf(true): unmarshalerFn(true), + } + + var out MyCompositeStructOne + avMap := map[string]types.AttributeValue{ + "a": myStructAVMap, + } + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + dec := NewDecoder() + for t, fn := range fns { + if err := dec.RegisterUnmarshaler(t, fn); err != nil { + b.Error("unexpected error:", err) + } + } + if err := dec.Decode(&types.AttributeValueMemberM{Value: avMap}, &out); err != nil { + b.Error("unexpected error:", err) + } + } + }) +} diff --git a/feature/dynamodb/attributevalue/shared_test.go b/feature/dynamodb/attributevalue/shared_test.go index 544e99585cb..f39d106a595 100644 --- a/feature/dynamodb/attributevalue/shared_test.go +++ b/feature/dynamodb/attributevalue/shared_test.go @@ -395,6 +395,71 @@ var sharedMapTestCases = map[string]struct { }, } +var sharedTypeMarshalersTestCases = map[string]struct { + in types.AttributeValue + actual, expected interface{} +}{ + "binary slice": { + in: &types.AttributeValueMemberS{Value: "[]byte{48, 49}"}, + actual: &[]byte{}, + expected: []byte{48, 49}, + }, + "binary slice pointer": { + in: &types.AttributeValueMemberS{Value: "[]byte{48, 49}"}, + actual: func() **[]byte { + v := make([]byte, 0, 10) + v2 := &v + return &v2 + }(), + expected: []byte{48, 49}, + }, + "bool": { + in: &types.AttributeValueMemberS{Value: "true"}, + actual: new(bool), + expected: true, + }, + "list": { + in: &types.AttributeValueMemberS{Value: "[123]"}, + actual: &[]int{}, + expected: []int{123}, + }, + "list, interface": { + in: &types.AttributeValueMemberS{Value: "1, 2, 3"}, + actual: &[]interface{}{}, + expected: []interface{}{1, "2", 3}, + }, + "map, interface": { + in: &types.AttributeValueMemberS{Value: "{\"abc\": 123}"}, + actual: &map[string]int{}, + expected: map[string]int{"abc": 123}, + }, + "map, struct": { + in: &types.AttributeValueMemberS{Value: "{\"ABC\": 123}"}, + actual: &struct{ Abc int }{}, + expected: struct{ Abc int }{Abc: 123}, + }, + "int": { + in: &types.AttributeValueMemberS{Value: "123"}, + actual: new(int), + expected: 123, + }, + "float": { + in: &types.AttributeValueMemberS{Value: "123.1"}, + actual: new(float64), + expected: 123.1, + }, + "string": { + in: &types.AttributeValueMemberSS{Value: []string{"abc"}}, + actual: new(string), + expected: "abc", + }, + "aliased string": { + in: &types.AttributeValueMemberSS{Value: []string{"abc"}}, + actual: new(testAliasedString), + expected: testAliasedString("abc"), + }, +} + func assertConvertTest(t *testing.T, actual, expected interface{}, err, expectedErr error) { t.Helper() diff --git a/feature/dynamodbstreams/attributevalue/decode.go b/feature/dynamodbstreams/attributevalue/decode.go index b722b3eec38..567f9d808ba 100644 --- a/feature/dynamodbstreams/attributevalue/decode.go +++ b/feature/dynamodbstreams/attributevalue/decode.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "strconv" + "sync" "time" "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types" @@ -235,7 +236,9 @@ type DecoderOptions struct { // A Decoder provides unmarshaling AttributeValues to Go value types. type Decoder struct { - options DecoderOptions + options DecoderOptions + unmarshalersLock sync.RWMutex + unmarshalers map[reflect.Type]func(types.AttributeValue) (interface{}, error) } // NewDecoder creates a new Decoder with default configuration. Use @@ -288,14 +291,14 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag) var u Unmarshaler _, isNull := av.(*types.AttributeValueMemberNULL) if av == nil || isNull { - u, v = indirect(v, indirectOptions{decodeNull: true}) + u, v = d.indirect(v, indirectOptions{decodeNull: true}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(av) } return d.decodeNull(v) } - u, v = indirect(v, indirectOptions{}) + u, v = d.indirect(v, indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(av) } @@ -333,6 +336,34 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag) } } +// RegisterUnmarshaler registers a custom unmarshaler to use for a provided type. +// +// Precedence is given to registered unmarshalers that operate on concrete types, +// then the UnmarshalDynamoDBStreamsAttributeValue method, and lastly the default behavior of Decode. +func (d *Decoder) RegisterUnmarshaler(t reflect.Type, fn func(types.AttributeValue) (interface{}, error)) error { + if t == nil { + return fmt.Errorf("type can't be nil") + } + if fn == nil { + return fmt.Errorf("unmarshaler function can't be nil") + } + switch t.Kind() { + case reflect.Ptr, reflect.Chan, reflect.Invalid, reflect.Func, reflect.UnsafePointer: + return fmt.Errorf("not supported kind %q", t.Kind()) + default: + d.unmarshalersLock.Lock() + defer d.unmarshalersLock.Unlock() + if d.unmarshalers == nil { + d.unmarshalers = map[reflect.Type]func(types.AttributeValue) (interface{}, error){t: fn} + } else if _, ok := d.unmarshalers[t]; ok { + return fmt.Errorf("unmarshaler has already been registered for type %q", t) + } else { + d.unmarshalers[t] = fn + } + return nil + } +} + func (d *Decoder) decodeBinary(b []byte, v reflect.Value) error { if v.Kind() == reflect.Interface { buf := make([]byte, len(b)) @@ -420,7 +451,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), indirectOptions{}) + u, elem := d.indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(&types.AttributeValueMemberBS{Value: bs}) } @@ -555,7 +586,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), indirectOptions{}) + u, elem := d.indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(&types.AttributeValueMemberNS{Value: ns}) } @@ -634,7 +665,7 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val for k, av := range avMap { key := reflect.New(keyType).Elem() // handle pointer keys - _, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true}) + _, indirectKey := d.indirect(key, indirectOptions{skipUnmarshaler: true}) if err := decodeMapKey(k, indirectKey, tag{}); err != nil { return &UnmarshalTypeError{ Value: fmt.Sprintf("map key %q", k), @@ -777,7 +808,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), indirectOptions{}) + u, elem := d.indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(&types.AttributeValueMemberSS{Value: ss}) } @@ -820,12 +851,26 @@ type indirectOptions struct { skipUnmarshaler bool } +type typeUnmarshaler struct { + fn func(types.AttributeValue) (interface{}, error) + in reflect.Value +} + +func (u typeUnmarshaler) UnmarshalDynamoDBStreamsAttributeValue(av types.AttributeValue) error { + if out, err := u.fn(av); err != nil { + return err + } else if out != nil { + u.in.Set(reflect.ValueOf(out)) + } + return nil +} + // indirect will walk a value's interface or pointer value types. Returning // the final value or the value a unmarshaler is defined on. // // Based on the enoding/json type reflect value type indirection in Go Stdlib // https://golang.org/src/encoding/json/decode.go indirect func. -func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) { +func (d *Decoder) indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) { // Issue #24153 indicates that it is generally not a guaranteed property // that you may round-trip a reflect.Value by calling Value.Addr().Elem() // and expect the value to still be settable for values derived from @@ -879,6 +924,9 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } + if u := d.getTypeUnmarshaler(v); u != nil { + return u, reflect.Value{} + } if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { return u, reflect.Value{} @@ -896,6 +944,19 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value return nil, v } +func (d *Decoder) getTypeUnmarshaler(v reflect.Value) Unmarshaler { + v = valueElem(v) + if v.Kind() == reflect.Invalid { + return nil + } + d.unmarshalersLock.RLock() + defer d.unmarshalersLock.RUnlock() + if fn, ok := d.unmarshalers[v.Type()]; ok { + return typeUnmarshaler{fn, v} + } + return nil +} + // A Number represents a Attributevalue number literal. type Number string diff --git a/feature/dynamodbstreams/attributevalue/decode_test.go b/feature/dynamodbstreams/attributevalue/decode_test.go index 2d6ea09244f..fe0d8f6a5a1 100644 --- a/feature/dynamodbstreams/attributevalue/decode_test.go +++ b/feature/dynamodbstreams/attributevalue/decode_test.go @@ -1174,3 +1174,43 @@ func TestUnmarshalMap_keyPtrTypes(t *testing.T) { } } + +func TestDecoderTypeUnmarshalers(t *testing.T) { + for name, c := range sharedTypeMarshalersTestCases { + t.Run(name, func(t *testing.T) { + called := false + dec := NewDecoder() + err := dec.RegisterUnmarshaler(reflect.TypeOf(c.expected), func(value types.AttributeValue) (interface{}, error) { + called = true + return c.expected, nil + }) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + err = dec.Decode(c.in, c.actual) + if !called { + t.Fatalf("expected unmarshaler to be called") + } + assertConvertTest(t, c.actual, c.expected, err, nil) + }) + } +} + +func TestDecoderNotSupportedUnmarshalerType(t *testing.T) { + cases := map[string]reflect.Type{ + "pointer": reflect.TypeOf(new(string)), + "channel": reflect.TypeOf(make(chan int)), + "func": reflect.TypeOf(func() {}), + } + + for name, caseType := range cases { + t.Run(name, func(t *testing.T) { + err := NewDecoder().RegisterUnmarshaler(caseType, func(value types.AttributeValue) (interface{}, error) { + return nil, nil + }) + if err == nil { + t.Errorf("expect error when registering unmarshaler for unsupported type %q", caseType) + } + }) + } +} diff --git a/feature/dynamodbstreams/attributevalue/encode.go b/feature/dynamodbstreams/attributevalue/encode.go index 28214080a09..40568fc7ec6 100644 --- a/feature/dynamodbstreams/attributevalue/encode.go +++ b/feature/dynamodbstreams/attributevalue/encode.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "strconv" + "sync" "time" "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types" @@ -384,7 +385,9 @@ type EncoderOptions struct { // An Encoder provides marshaling Go value types to AttributeValues. type Encoder struct { - options EncoderOptions + options EncoderOptions + marshalersLock sync.RWMutex + marshalers map[reflect.Type]func(interface{}) (types.AttributeValue, error) } // NewEncoder creates a new Encoder with default configuration. Use @@ -414,6 +417,34 @@ func (e *Encoder) Encode(in interface{}) (types.AttributeValue, error) { return e.encode(reflect.ValueOf(in), tag{}) } +// RegisterMarshaler registers a custom marshaler to use for a provided type. +// +// Precedence is given to registered marshalers that operate on concrete types, +// then the MarshalDynamoDBStreamsAttributeValue method, and lastly the default behavior of Encode. +func (e *Encoder) RegisterMarshaler(t reflect.Type, fn func(interface{}) (types.AttributeValue, error)) error { + if t == nil { + return fmt.Errorf("type can't be nil") + } + if fn == nil { + return fmt.Errorf("marshaller function can't be nil") + } + switch t.Kind() { + case reflect.Ptr, reflect.Chan, reflect.Invalid, reflect.Func, reflect.UnsafePointer: + return fmt.Errorf("not supported kind %q", t.Kind()) + default: + e.marshalersLock.Lock() + defer e.marshalersLock.Unlock() + if e.marshalers == nil { + e.marshalers = map[reflect.Type]func(interface{}) (types.AttributeValue, error){t: fn} + } else if _, ok := e.marshalers[t]; ok { + return fmt.Errorf("marshaler has already been registered for type %q", t) + } else { + e.marshalers[t] = fn + } + return nil + } +} + func (e *Encoder) encode(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { // Ignore fields explicitly marked to be skipped. if fieldTag.Ignore { @@ -438,7 +469,7 @@ func (e *Encoder) encode(v reflect.Value, fieldTag tag) (types.AttributeValue, e v = valueElem(v) if v.Kind() != reflect.Invalid { - if av, err := tryMarshaler(v); err != nil { + if av, err := e.tryMarshaler(v); err != nil { return nil, err } else if av != nil { return av, nil @@ -714,7 +745,7 @@ func (e *Encoder) encodeScalar(v reflect.Value, fieldTag tag) (types.AttributeVa } func (e *Encoder) encodeNumber(v reflect.Value) (types.AttributeValue, error) { - if av, err := tryMarshaler(v); err != nil { + if av, err := e.tryMarshaler(v); err != nil { return nil, err } else if av != nil { return av, nil @@ -742,7 +773,7 @@ func (e *Encoder) encodeNumber(v reflect.Value) (types.AttributeValue, error) { } func (e *Encoder) encodeString(v reflect.Value) (types.AttributeValue, error) { - if av, err := tryMarshaler(v); err != nil { + if av, err := e.tryMarshaler(v); err != nil { return nil, err } else if av != nil { return av, nil @@ -758,6 +789,34 @@ func (e *Encoder) encodeString(v reflect.Value) (types.AttributeValue, error) { } } +func (e *Encoder) tryMarshaler(v reflect.Value) (types.AttributeValue, error) { + if av, err := e.tryTypeMarshaler(v); err != nil { + return nil, err + } else if av != nil { + return av, nil + } + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + } + if v.Type().NumMethod() == 0 { + return nil, nil + } + if m, ok := v.Interface().(Marshaler); ok { + return m.MarshalDynamoDBStreamsAttributeValue() + } + + return nil, nil +} + +func (e *Encoder) tryTypeMarshaler(v reflect.Value) (types.AttributeValue, error) { + e.marshalersLock.RLock() + defer e.marshalersLock.RUnlock() + if m, ok := e.marshalers[v.Type()]; ok { + return m(v.Interface()) + } + return nil, nil +} + func encodeInt(i int64) string { return strconv.FormatInt(i, 10) } @@ -832,22 +891,6 @@ func isNullableZeroValue(v reflect.Value) bool { return false } -func tryMarshaler(v reflect.Value) (types.AttributeValue, error) { - if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { - v = v.Addr() - } - - if v.Type().NumMethod() == 0 { - return nil, nil - } - - if m, ok := v.Interface().(Marshaler); ok { - return m.MarshalDynamoDBStreamsAttributeValue() - } - - return nil, nil -} - // An InvalidMarshalError is an error type representing an error // occurring when marshaling a Go value type to an AttributeValue. type InvalidMarshalError struct { diff --git a/feature/dynamodbstreams/attributevalue/encode_test.go b/feature/dynamodbstreams/attributevalue/encode_test.go index 97c7399bd04..f99a2bff560 100644 --- a/feature/dynamodbstreams/attributevalue/encode_test.go +++ b/feature/dynamodbstreams/attributevalue/encode_test.go @@ -619,3 +619,43 @@ func TestMarshalMap_keyTypes(t *testing.T) { }) } } + +func TestEncoderTypesMarshalers(t *testing.T) { + for name, c := range sharedTypeMarshalersTestCases { + t.Run(name, func(t *testing.T) { + called := false + enc := NewEncoder() + err := enc.RegisterMarshaler(reflect.TypeOf(c.expected), func(i interface{}) (types.AttributeValue, error) { + called = true + return c.in, nil + }) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + av, err := enc.Encode(c.expected) + if !called { + t.Fatalf("expected marshaler to be called") + } + assertConvertTest(t, av, c.in, err, nil) + }) + } +} + +func TestEncoderNotSupportedMarshalerType(t *testing.T) { + cases := map[string]reflect.Type{ + "pointer": reflect.TypeOf(new(string)), + "channel": reflect.TypeOf(make(chan int)), + "func": reflect.TypeOf(func() {}), + } + + for name, caseType := range cases { + t.Run(name, func(t *testing.T) { + err := NewEncoder().RegisterMarshaler(caseType, func(i interface{}) (types.AttributeValue, error) { + return nil, nil + }) + if err == nil { + t.Errorf("expect error when registering marshaler for unsupported type %q", caseType) + } + }) + } +} diff --git a/feature/dynamodbstreams/attributevalue/marshaler_test.go b/feature/dynamodbstreams/attributevalue/marshaler_test.go index beb3d86e82d..bf414b83c15 100644 --- a/feature/dynamodbstreams/attributevalue/marshaler_test.go +++ b/feature/dynamodbstreams/attributevalue/marshaler_test.go @@ -709,3 +709,100 @@ func Test_Encode_YAML_TagKey(t *testing.T) { compareObjects(t, expected, actual) } + +func BenchmarkEncoderTypeMarshaler(b *testing.B) { + fieldCache = &fieldCacher{} + + simple := simpleMarshalStruct{ + String: "abc", + Int: 123, + Uint: 123, + Float32: 123.321, + Float64: 123.321, + Bool: true, + } + + type MyCompositeStruct struct { + A simpleMarshalStruct `dynamodbav:"a"` + } + + var marshalerFn = func(value string) func(i interface{}) (types.AttributeValue, error) { + return func(interface{}) (types.AttributeValue, error) { + return &types.AttributeValueMemberS{Value: value}, nil + } + } + + fns := map[reflect.Type]func(i interface{}) (types.AttributeValue, error){ + reflect.TypeOf("abc"): marshalerFn("abc"), + reflect.TypeOf(123): marshalerFn("123"), + reflect.TypeOf(uint(123)): marshalerFn("uint(123)"), + reflect.TypeOf(float32(123.321)): marshalerFn("float32(123.321)"), + reflect.TypeOf(123.321): marshalerFn("123.321"), + reflect.TypeOf(true): marshalerFn("true"), + } + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + enc := NewEncoder() + for t, fn := range fns { + if err := enc.RegisterMarshaler(t, fn); err != nil { + b.Error("unexpected error:", err) + } + } + if _, err := enc.Encode(MyCompositeStruct{ + A: simple, + }); err != nil { + b.Error("unexpected error:", err) + } + } + }) +} + +func BenchmarkDecoderTypeUnmarshaler(b *testing.B) { + myStructAVMap, _ := Marshal(simpleMarshalStruct{ + String: "abc", + Int: 123, + Uint: 123, + Float32: 123.321, + Float64: 123.321, + Bool: true, + }) + + type MyCompositeStructOne struct { + A simpleMarshalStruct `dynamodbav:"a"` + } + + var unmarshalerFn = func(value interface{}) func(types.AttributeValue) (interface{}, error) { + return func(types.AttributeValue) (interface{}, error) { + return value, nil + } + } + + fns := map[reflect.Type]func(types.AttributeValue) (interface{}, error){ + reflect.TypeOf("abc"): unmarshalerFn("abc"), + reflect.TypeOf(123): unmarshalerFn(123), + reflect.TypeOf(uint(123)): unmarshalerFn(uint(123)), + reflect.TypeOf(float32(123.321)): unmarshalerFn(float32(123.321)), + reflect.TypeOf(123.321): unmarshalerFn(123.321), + reflect.TypeOf(true): unmarshalerFn(true), + } + + var out MyCompositeStructOne + avMap := map[string]types.AttributeValue{ + "a": myStructAVMap, + } + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + dec := NewDecoder() + for t, fn := range fns { + if err := dec.RegisterUnmarshaler(t, fn); err != nil { + b.Error("unexpected error:", err) + } + } + if err := dec.Decode(&types.AttributeValueMemberM{Value: avMap}, &out); err != nil { + b.Error("unexpected error:", err) + } + } + }) +} diff --git a/feature/dynamodbstreams/attributevalue/shared_test.go b/feature/dynamodbstreams/attributevalue/shared_test.go index a8fc7263d4d..eff23f1f297 100644 --- a/feature/dynamodbstreams/attributevalue/shared_test.go +++ b/feature/dynamodbstreams/attributevalue/shared_test.go @@ -395,6 +395,71 @@ var sharedMapTestCases = map[string]struct { }, } +var sharedTypeMarshalersTestCases = map[string]struct { + in types.AttributeValue + actual, expected interface{} +}{ + "binary slice": { + in: &types.AttributeValueMemberS{Value: "[]byte{48, 49}"}, + actual: &[]byte{}, + expected: []byte{48, 49}, + }, + "binary slice pointer": { + in: &types.AttributeValueMemberS{Value: "[]byte{48, 49}"}, + actual: func() **[]byte { + v := make([]byte, 0, 10) + v2 := &v + return &v2 + }(), + expected: []byte{48, 49}, + }, + "bool": { + in: &types.AttributeValueMemberS{Value: "true"}, + actual: new(bool), + expected: true, + }, + "list": { + in: &types.AttributeValueMemberS{Value: "[123]"}, + actual: &[]int{}, + expected: []int{123}, + }, + "list, interface": { + in: &types.AttributeValueMemberS{Value: "1, 2, 3"}, + actual: &[]interface{}{}, + expected: []interface{}{1, "2", 3}, + }, + "map, interface": { + in: &types.AttributeValueMemberS{Value: "{\"abc\": 123}"}, + actual: &map[string]int{}, + expected: map[string]int{"abc": 123}, + }, + "map, struct": { + in: &types.AttributeValueMemberS{Value: "{\"ABC\": 123}"}, + actual: &struct{ Abc int }{}, + expected: struct{ Abc int }{Abc: 123}, + }, + "int": { + in: &types.AttributeValueMemberS{Value: "123"}, + actual: new(int), + expected: 123, + }, + "float": { + in: &types.AttributeValueMemberS{Value: "123.1"}, + actual: new(float64), + expected: 123.1, + }, + "string": { + in: &types.AttributeValueMemberSS{Value: []string{"abc"}}, + actual: new(string), + expected: "abc", + }, + "aliased string": { + in: &types.AttributeValueMemberSS{Value: []string{"abc"}}, + actual: new(testAliasedString), + expected: testAliasedString("abc"), + }, +} + func assertConvertTest(t *testing.T, actual, expected interface{}, err, expectedErr error) { t.Helper()