Skip to content

Commit

Permalink
add codec option to use encoding.Text/Binary(Un)Marshaler when present (
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws authored Jun 5, 2024
1 parent 5b30240 commit 882127d
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 15 deletions.
8 changes: 8 additions & 0 deletions .changelog/a7a833ea4c9c42bcbecc39d0597c7b88.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "a7a833ea-4c9c-42bc-becc-39d0597c7b88",
"type": "feature",
"description": "Add codec options to use encoding.Text/Binary(Un)Marshaler when present on targets.",
"modules": [
"feature/dynamodb/attributevalue"
]
}
47 changes: 37 additions & 10 deletions feature/dynamodb/attributevalue/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,18 @@ type DecoderOptions struct {
// Default string parsing format is time.RFC3339
// Default number parsing format is seconds since January 1, 1970 UTC
DecodeTime DecodeTimeAttributes

// When enabled, the decoder will use implementations of
// encoding.TextUnmarshaler and encoding.BinaryUnmarshaler when present on
// unmarshaling targets.
//
// If a target implements [Unmarshaler], encoding unmarshaler
// implementations are ignored.
//
// If the attributevalue is a string, its underlying value will be used to
// call UnmarshalText on the target. If the attributevalue is a binary, its
// value will be used to call UnmarshalBinary.
UseEncodingUnmarshalers bool
}

// A Decoder provides unmarshaling AttributeValues to Go value types.
Expand Down Expand Up @@ -288,17 +300,30 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag)
var u Unmarshaler
_, isNull := av.(*types.AttributeValueMemberNULL)
if av == nil || isNull {
u, v = indirect(v, indirectOptions{decodeNull: true})
u, v = indirect[Unmarshaler](v, indirectOptions{decodeNull: true})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}
return d.decodeNull(v)
}

u, v = indirect(v, indirectOptions{})
v0 := v
u, v = indirect[Unmarshaler](v, indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(av)
}
if d.options.UseEncodingUnmarshalers {
if s, ok := av.(*types.AttributeValueMemberS); ok {
if u, _ := indirect[encoding.TextUnmarshaler](v0, indirectOptions{}); u != nil {
return u.UnmarshalText([]byte(s.Value))
}
}
if b, ok := av.(*types.AttributeValueMemberB); ok {
if u, _ := indirect[encoding.BinaryUnmarshaler](v0, indirectOptions{}); u != nil {
return u.UnmarshalBinary(b.Value)
}
}
}

switch tv := av.(type) {
case *types.AttributeValueMemberB:
Expand Down Expand Up @@ -420,7 +445,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), indirectOptions{})
u, elem := indirect[Unmarshaler](v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberBS{Value: bs})
}
Expand Down Expand Up @@ -555,7 +580,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), indirectOptions{})
u, elem := indirect[Unmarshaler](v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberNS{Value: ns})
}
Expand Down Expand Up @@ -634,7 +659,7 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val
for k, av := range avMap {
key := reflect.New(keyType).Elem()
// handle pointer keys
_, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true})
_, indirectKey := indirect[Unmarshaler](key, indirectOptions{skipUnmarshaler: true})
if err := decodeMapKey(k, indirectKey, tag{}); err != nil {
return &UnmarshalTypeError{
Value: fmt.Sprintf("map key %q", k),
Expand Down Expand Up @@ -777,7 +802,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error {
if !isArray {
v.SetLen(i + 1)
}
u, elem := indirect(v.Index(i), indirectOptions{})
u, elem := indirect[Unmarshaler](v.Index(i), indirectOptions{})
if u != nil {
return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberSS{Value: ss})
}
Expand Down Expand Up @@ -825,7 +850,7 @@ type indirectOptions struct {
//
// Based on the enoding/json type reflect value type indirection in Go Stdlib
// https://golang.org/src/encoding/json/decode.go indirect func.
func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) {
func indirect[U any](v reflect.Value, opts indirectOptions) (U, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
Expand Down Expand Up @@ -859,7 +884,8 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value
continue
}
if e.Kind() != reflect.Ptr && e.IsValid() {
return nil, e
var u U
return u, e
}
}
if v.Kind() != reflect.Ptr {
Expand All @@ -880,7 +906,7 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value
v.Set(reflect.New(v.Type().Elem()))
}
if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
if u, ok := v.Interface().(U); ok {
return u, reflect.Value{}
}
}
Expand All @@ -893,7 +919,8 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value
}
}

return nil, v
var u U
return u, v
}

// A Number represents a Attributevalue number literal.
Expand Down
95 changes: 95 additions & 0 deletions feature/dynamodb/attributevalue/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -1173,3 +1174,97 @@ func TestUnmarshalMap_keyPtrTypes(t *testing.T) {
}

}

type textUnmarshalerString string

func (v *textUnmarshalerString) UnmarshalText(text []byte) error {
*v = textUnmarshalerString("[[" + string(text) + "]]")
return nil
}

func TestUnmarshalTextString(t *testing.T) {
in := &types.AttributeValueMemberS{Value: "foo"}

var actual textUnmarshalerString
err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) {
o.UseEncodingUnmarshalers = true
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if string(actual) != "[[foo]]" {
t.Errorf("expected [[foo]], got %s", actual)
}
}

func TestUnmarshalTextStringDisabled(t *testing.T) {
in := &types.AttributeValueMemberS{Value: "foo"}

var actual textUnmarshalerString
err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) {
o.UseEncodingUnmarshalers = false
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if string(actual) != "foo" {
t.Errorf("expected foo, got %s", actual)
}
}

type textUnmarshalerStruct struct {
I, J string
}

func (v *textUnmarshalerStruct) UnmarshalText(text []byte) error {
parts := strings.Split(string(text), ";")
v.I = parts[0]
v.J = parts[1]
return nil
}

func TestUnmarshalTextStruct(t *testing.T) {
in := &types.AttributeValueMemberS{Value: "foo;bar"}

var actual textUnmarshalerStruct
err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) {
o.UseEncodingUnmarshalers = true
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

expected := textUnmarshalerStruct{"foo", "bar"}
if actual != expected {
t.Errorf("expected %v, got %v", expected, actual)
}
}

type binaryUnmarshaler struct {
I, J byte
}

func (v *binaryUnmarshaler) UnmarshalBinary(b []byte) error {
v.I = b[0]
v.J = b[1]
return nil
}

func TestUnmarshalBinary(t *testing.T) {
in := &types.AttributeValueMemberB{Value: []byte{1, 2}}

var actual binaryUnmarshaler
err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) {
o.UseEncodingUnmarshalers = true
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

expected := binaryUnmarshaler{1, 2}
if actual != expected {
t.Errorf("expected %v, got %v", expected, actual)
}
}
48 changes: 43 additions & 5 deletions feature/dynamodb/attributevalue/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,7 @@ func MarshalListWithOptions(in interface{}, optFns ...func(*EncoderOptions)) ([]
return asList.Value, nil
}

// EncoderOptions is a collection of options shared between marshaling
// and unmarshaling
// EncoderOptions is a collection of options used by the marshaler.
type EncoderOptions struct {
// Support other custom struct tag keys, such as `yaml`, `json`, or `toml`.
// Note that values provided with a custom TagKey must also be supported
Expand All @@ -380,6 +379,19 @@ type EncoderOptions struct {
//
// Default encoding is time.RFC3339Nano in a DynamoDB String (S) data type.
EncodeTime func(time.Time) (types.AttributeValue, error)

// When enabled, the encoder will use implementations of
// encoding.TextMarshaler and encoding.BinaryMarshaler when present on
// marshaled values.
//
// Implementations are checked in the following order:
// - [Marshaler]
// - encoding.TextMarshaler
// - encoding.BinaryMarshaler
//
// The results of a MarshalText call will convert to string (S), results
// from a MarshalBinary call will convert to binary (B).
UseEncodingMarshalers bool
}

// An Encoder provides marshaling Go value types to AttributeValues.
Expand Down Expand Up @@ -438,7 +450,7 @@ func (e *Encoder) encode(v reflect.Value, fieldTag tag) (types.AttributeValue, e
v = valueElem(v)

if v.Kind() != reflect.Invalid {
if av, err := tryMarshaler(v); err != nil {
if av, err := e.tryMarshaler(v); err != nil {
return nil, err
} else if av != nil {
return av, nil
Expand Down Expand Up @@ -822,7 +834,7 @@ func isNullableZeroValue(v reflect.Value) bool {
return false
}

func tryMarshaler(v reflect.Value) (types.AttributeValue, error) {
func (e *Encoder) tryMarshaler(v reflect.Value) (types.AttributeValue, error) {
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
v = v.Addr()
}
Expand All @@ -831,9 +843,35 @@ func tryMarshaler(v reflect.Value) (types.AttributeValue, error) {
return nil, nil
}

if m, ok := v.Interface().(Marshaler); ok {
i := v.Interface()
if m, ok := i.(Marshaler); ok {
return m.MarshalDynamoDBAttributeValue()
}
if e.options.UseEncodingMarshalers {
return e.tryEncodingMarshaler(i)
}

return nil, nil
}

func (e *Encoder) tryEncodingMarshaler(v any) (types.AttributeValue, error) {
if m, ok := v.(encoding.TextMarshaler); ok {
s, err := m.MarshalText()
if err != nil {
return nil, err
}

return &types.AttributeValueMemberS{Value: string(s)}, nil
}

if m, ok := v.(encoding.BinaryMarshaler); ok {
b, err := m.MarshalBinary()
if err != nil {
return nil, err
}

return &types.AttributeValueMemberB{Value: b}, nil
}

return nil, nil
}
Expand Down
Loading

0 comments on commit 882127d

Please sign in to comment.