Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom type marshal/unmarshal functions #2060

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 71 additions & 8 deletions feature/dynamodb/attributevalue/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"strconv"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
Expand Down Expand Up @@ -235,7 +236,9 @@ type DecoderOptions struct {

// A Decoder provides unmarshaling AttributeValues to Go value types.
type Decoder struct {
options DecoderOptions
options DecoderOptions
unmarshalersLock sync.RWMutex
unmarshalers map[reflect.Type]func(types.AttributeValue) (interface{}, error)
}

// NewDecoder creates a new Decoder with default configuration. Use
Expand Down Expand Up @@ -279,6 +282,34 @@ func (d *Decoder) Decode(av types.AttributeValue, out interface{}, opts ...func(
return d.decode(av, v, tag{})
}

// RegisterUnmarshaler registers a custom unmarshaler to use for a provided type.
//
// Precedence is given to registered unmarshalers that operate on concrete types,
// then the UnmarshalDynamoDBAttributeValue method, and lastly the default behavior of Decode.
func (d *Decoder) RegisterUnmarshaler(t reflect.Type, fn func(types.AttributeValue) (interface{}, error)) error {
if t == nil {
return fmt.Errorf("type can't be nil")
}
if fn == nil {
return fmt.Errorf("unmarshaler function can't be nil")
}
switch t.Kind() {
case reflect.Ptr, reflect.Chan, reflect.Invalid, reflect.Func, reflect.UnsafePointer:
return fmt.Errorf("not supported kind %q", t.Kind())
default:
d.unmarshalersLock.Lock()
defer d.unmarshalersLock.Unlock()
if d.unmarshalers == nil {
d.unmarshalers = map[reflect.Type]func(types.AttributeValue) (interface{}, error){t: fn}
} else if _, ok := d.unmarshalers[t]; ok {
return fmt.Errorf("unmarshaler has already been registered for type %q", t)
} else {
d.unmarshalers[t] = fn
}
return nil
}
}

var stringInterfaceMapType = reflect.TypeOf(map[string]interface{}(nil))
var byteSliceType = reflect.TypeOf([]byte(nil))
var byteSliceSliceType = reflect.TypeOf([][]byte(nil))
Expand All @@ -288,14 +319,14 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag)
var u Unmarshaler
_, isNull := av.(*types.AttributeValueMemberNULL)
if av == nil || isNull {
u, v = indirect(v, indirectOptions{decodeNull: true})
u, v = d.indirect(v, indirectOptions{decodeNull: true})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}
return d.decodeNull(v)
}

u, v = indirect(v, indirectOptions{})
u, v = d.indirect(v, indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}
Expand Down Expand Up @@ -420,7 +451,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), indirectOptions{})
u, elem := d.indirect(v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberBS{Value: bs})
}
Expand Down Expand Up @@ -555,7 +586,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), indirectOptions{})
u, elem := d.indirect(v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberNS{Value: ns})
}
Expand Down Expand Up @@ -634,7 +665,7 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val
for k, av := range avMap {
key := reflect.New(keyType).Elem()
// handle pointer keys
_, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true})
_, indirectKey := d.indirect(key, indirectOptions{skipUnmarshaler: true})
if err := decodeMapKey(k, indirectKey, tag{}); err != nil {
return &UnmarshalTypeError{
Value: fmt.Sprintf("map key %q", k),
Expand Down Expand Up @@ -777,7 +808,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), indirectOptions{})
u, elem := d.indirect(v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberSS{Value: ss})
}
Expand Down Expand Up @@ -820,12 +851,26 @@ type indirectOptions struct {
skipUnmarshaler bool
}

type typeUnmarshaler struct {
fn func(types.AttributeValue) (interface{}, error)
in reflect.Value
}

func (u typeUnmarshaler) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error {
if out, err := u.fn(av); err != nil {
return err
} else if out != nil {
u.in.Set(reflect.ValueOf(out))
}
return nil
}

// indirect will walk a value's interface or pointer value types. Returning
// the final value or the value a unmarshaler is defined on.
//
// Based on the enoding/json type reflect value type indirection in Go Stdlib
// https://golang.org/src/encoding/json/decode.go indirect func.
func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) {
func (d *Decoder) indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
Expand Down Expand Up @@ -879,6 +924,11 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}

if u := d.getTypeUnmarshaler(v); u != nil {
return u, reflect.Value{}
}

if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
return u, reflect.Value{}
Expand All @@ -896,6 +946,19 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value
return nil, v
}

func (d *Decoder) getTypeUnmarshaler(v reflect.Value) Unmarshaler {
v = valueElem(v)
if v.Kind() == reflect.Invalid {
return nil
}
d.unmarshalersLock.RLock()
defer d.unmarshalersLock.RUnlock()
if fn, ok := d.unmarshalers[v.Type()]; ok {
return typeUnmarshaler{fn, v}
}
return nil
}

// A Number represents a Attributevalue number literal.
type Number string

Expand Down
39 changes: 39 additions & 0 deletions feature/dynamodb/attributevalue/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1172,5 +1172,44 @@ func TestUnmarshalMap_keyPtrTypes(t *testing.T) {
t.Errorf("expect %v key not found", *k)
}
}
}

func TestDecoderTypeUnmarshalers(t *testing.T) {
for name, c := range sharedTypeMarshalersTestCases {
t.Run(name, func(t *testing.T) {
called := false
dec := NewDecoder()
err := dec.RegisterUnmarshaler(reflect.TypeOf(c.expected), func(value types.AttributeValue) (interface{}, error) {
called = true
return c.expected, nil
})
if err != nil {
t.Errorf("expect nil, got %v", err)
}
err = dec.Decode(c.in, c.actual)
if !called {
t.Fatalf("expected unmarshaler to be called")
}
assertConvertTest(t, c.actual, c.expected, err, nil)
})
}
}

func TestDecoderNotSupportedUnmarshalerType(t *testing.T) {
cases := map[string]reflect.Type{
"pointer": reflect.TypeOf(new(string)),
"channel": reflect.TypeOf(make(chan int)),
"func": reflect.TypeOf(func() {}),
}

for name, caseType := range cases {
t.Run(name, func(t *testing.T) {
err := NewDecoder().RegisterUnmarshaler(caseType, func(value types.AttributeValue) (interface{}, error) {
return nil, nil
})
if err == nil {
t.Errorf("expect error when registering unmarshaler for unsupported type %q", caseType)
}
})
}
}
84 changes: 63 additions & 21 deletions feature/dynamodb/attributevalue/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"strconv"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
Expand Down Expand Up @@ -384,7 +385,9 @@ type EncoderOptions struct {

// An Encoder provides marshaling Go value types to AttributeValues.
type Encoder struct {
options EncoderOptions
options EncoderOptions
marshalersLock sync.RWMutex
marshalers map[reflect.Type]func(interface{}) (types.AttributeValue, error)
}

// NewEncoder creates a new Encoder with default configuration. Use
Expand Down Expand Up @@ -414,6 +417,34 @@ func (e *Encoder) Encode(in interface{}) (types.AttributeValue, error) {
return e.encode(reflect.ValueOf(in), tag{})
}

// RegisterMarshaler registers a custom marshaler to use for a provided type.
//
// Precedence is given to registered marshalers that operate on concrete types,
// then the MarshalDynamoDBAttributeValue method, and lastly the default behavior of Encode.
func (e *Encoder) RegisterMarshaler(t reflect.Type, fn func(interface{}) (types.AttributeValue, error)) error {
if t == nil {
return fmt.Errorf("type can't be nil")
}
if fn == nil {
return fmt.Errorf("marshaller function can't be nil")
}
switch t.Kind() {
case reflect.Ptr, reflect.Chan, reflect.Invalid, reflect.Func, reflect.UnsafePointer:
return fmt.Errorf("not supported kind %q", t.Kind())
default:
e.marshalersLock.Lock()
defer e.marshalersLock.Unlock()
if e.marshalers == nil {
e.marshalers = map[reflect.Type]func(interface{}) (types.AttributeValue, error){t: fn}
} else if _, ok := e.marshalers[t]; ok {
return fmt.Errorf("marshaler has already been registered for type %q", t)
} else {
e.marshalers[t] = fn
}
return nil
}
}

func (e *Encoder) encode(v reflect.Value, fieldTag tag) (types.AttributeValue, error) {
// Ignore fields explicitly marked to be skipped.
if fieldTag.Ignore {
Expand All @@ -433,12 +464,11 @@ func (e *Encoder) encode(v reflect.Value, fieldTag tag) (types.AttributeValue, e
return encodeNull(), nil
}
}

// Handle both pointers and interface conversion into types
v = valueElem(v)

if v.Kind() != reflect.Invalid {
if av, err := tryMarshaler(v); err != nil {
if av, err := e.tryMarshaler(v); err != nil {
return nil, err
} else if av != nil {
return av, nil
Expand Down Expand Up @@ -714,7 +744,7 @@ func (e *Encoder) encodeScalar(v reflect.Value, fieldTag tag) (types.AttributeVa
}

func (e *Encoder) encodeNumber(v reflect.Value) (types.AttributeValue, error) {
if av, err := tryMarshaler(v); err != nil {
if av, err := e.tryMarshaler(v); err != nil {
return nil, err
} else if av != nil {
return av, nil
Expand Down Expand Up @@ -742,7 +772,7 @@ func (e *Encoder) encodeNumber(v reflect.Value) (types.AttributeValue, error) {
}

func (e *Encoder) encodeString(v reflect.Value) (types.AttributeValue, error) {
if av, err := tryMarshaler(v); err != nil {
if av, err := e.tryMarshaler(v); err != nil {
return nil, err
} else if av != nil {
return av, nil
Expand All @@ -758,6 +788,34 @@ func (e *Encoder) encodeString(v reflect.Value) (types.AttributeValue, error) {
}
}

func (e *Encoder) tryMarshaler(v reflect.Value) (types.AttributeValue, error) {
if av, err := e.tryTypeMarshaler(v); err != nil {
return nil, err
} else if av != nil {
return av, nil
}
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
v = v.Addr()
}
if v.Type().NumMethod() == 0 {
return nil, nil
}
if m, ok := v.Interface().(Marshaler); ok {
return m.MarshalDynamoDBAttributeValue()
}

return nil, nil
}

func (e *Encoder) tryTypeMarshaler(v reflect.Value) (types.AttributeValue, error) {
e.marshalersLock.RLock()
defer e.marshalersLock.RUnlock()
if m, ok := e.marshalers[v.Type()]; ok {
return m(v.Interface())
}
return nil, nil
}

func encodeInt(i int64) string {
return strconv.FormatInt(i, 10)
}
Expand Down Expand Up @@ -832,22 +890,6 @@ func isNullableZeroValue(v reflect.Value) bool {
return false
}

func tryMarshaler(v reflect.Value) (types.AttributeValue, error) {
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
v = v.Addr()
}

if v.Type().NumMethod() == 0 {
return nil, nil
}

if m, ok := v.Interface().(Marshaler); ok {
return m.MarshalDynamoDBAttributeValue()
}

return nil, nil
}

// An InvalidMarshalError is an error type representing an error
// occurring when marshaling a Go value type to an AttributeValue.
type InvalidMarshalError struct {
Expand Down
40 changes: 40 additions & 0 deletions feature/dynamodb/attributevalue/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,43 @@ func TestMarshalMap_keyTypes(t *testing.T) {
})
}
}

func TestEncoderTypesMarshalers(t *testing.T) {
for name, c := range sharedTypeMarshalersTestCases {
t.Run(name, func(t *testing.T) {
called := false
enc := NewEncoder()
err := enc.RegisterMarshaler(reflect.TypeOf(c.expected), func(i interface{}) (types.AttributeValue, error) {
called = true
return c.in, nil
})
if err != nil {
t.Errorf("expect nil, got %v", err)
}
av, err := enc.Encode(c.expected)
if !called {
t.Fatalf("expected marshaler to be called")
}
assertConvertTest(t, av, c.in, err, nil)
})
}
}

func TestEncoderNotSupportedMarshalerType(t *testing.T) {
cases := map[string]reflect.Type{
"pointer": reflect.TypeOf(new(string)),
"channel": reflect.TypeOf(make(chan int)),
"func": reflect.TypeOf(func() {}),
}

for name, caseType := range cases {
t.Run(name, func(t *testing.T) {
err := NewEncoder().RegisterMarshaler(caseType, func(i interface{}) (types.AttributeValue, error) {
return nil, nil
})
if err == nil {
t.Errorf("expect error when registering marshaler for unsupported type %q", caseType)
}
})
}
}
Loading