From 882127d3c9fdd1caff1cdc82551a108b0b00bbd0 Mon Sep 17 00:00:00 2001 From: Luc Talatinian <102624213+lucix-aws@users.noreply.github.com> Date: Wed, 5 Jun 2024 09:55:13 -0400 Subject: [PATCH] add codec option to use encoding.Text/Binary(Un)Marshaler when present (#2666) --- .../a7a833ea4c9c42bcbecc39d0597c7b88.json | 8 ++ feature/dynamodb/attributevalue/decode.go | 47 +++++++-- .../dynamodb/attributevalue/decode_test.go | 95 +++++++++++++++++++ feature/dynamodb/attributevalue/encode.go | 48 +++++++++- .../dynamodb/attributevalue/encode_test.go | 89 +++++++++++++++++ 5 files changed, 272 insertions(+), 15 deletions(-) create mode 100644 .changelog/a7a833ea4c9c42bcbecc39d0597c7b88.json diff --git a/.changelog/a7a833ea4c9c42bcbecc39d0597c7b88.json b/.changelog/a7a833ea4c9c42bcbecc39d0597c7b88.json new file mode 100644 index 00000000000..354cc54c2f6 --- /dev/null +++ b/.changelog/a7a833ea4c9c42bcbecc39d0597c7b88.json @@ -0,0 +1,8 @@ +{ + "id": "a7a833ea-4c9c-42bc-becc-39d0597c7b88", + "type": "feature", + "description": "Add codec options to use encoding.Text/Binary(Un)Marshaler when present on targets.", + "modules": [ + "feature/dynamodb/attributevalue" + ] +} \ No newline at end of file diff --git a/feature/dynamodb/attributevalue/decode.go b/feature/dynamodb/attributevalue/decode.go index a0a8b53f76f..15a94f4b24d 100644 --- a/feature/dynamodb/attributevalue/decode.go +++ b/feature/dynamodb/attributevalue/decode.go @@ -231,6 +231,18 @@ type DecoderOptions struct { // Default string parsing format is time.RFC3339 // Default number parsing format is seconds since January 1, 1970 UTC DecodeTime DecodeTimeAttributes + + // When enabled, the decoder will use implementations of + // encoding.TextUnmarshaler and encoding.BinaryUnmarshaler when present on + // unmarshaling targets. + // + // If a target implements [Unmarshaler], encoding unmarshaler + // implementations are ignored. + // + // If the attributevalue is a string, its underlying value will be used to + // call UnmarshalText on the target. If the attributevalue is a binary, its + // value will be used to call UnmarshalBinary. + UseEncodingUnmarshalers bool } // A Decoder provides unmarshaling AttributeValues to Go value types. @@ -288,17 +300,30 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag) var u Unmarshaler _, isNull := av.(*types.AttributeValueMemberNULL) if av == nil || isNull { - u, v = indirect(v, indirectOptions{decodeNull: true}) + u, v = indirect[Unmarshaler](v, indirectOptions{decodeNull: true}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(av) } return d.decodeNull(v) } - u, v = indirect(v, indirectOptions{}) + v0 := v + u, v = indirect[Unmarshaler](v, indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(av) } + if d.options.UseEncodingUnmarshalers { + if s, ok := av.(*types.AttributeValueMemberS); ok { + if u, _ := indirect[encoding.TextUnmarshaler](v0, indirectOptions{}); u != nil { + return u.UnmarshalText([]byte(s.Value)) + } + } + if b, ok := av.(*types.AttributeValueMemberB); ok { + if u, _ := indirect[encoding.BinaryUnmarshaler](v0, indirectOptions{}); u != nil { + return u.UnmarshalBinary(b.Value) + } + } + } switch tv := av.(type) { case *types.AttributeValueMemberB: @@ -420,7 +445,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), indirectOptions{}) + u, elem := indirect[Unmarshaler](v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberBS{Value: bs}) } @@ -555,7 +580,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), indirectOptions{}) + u, elem := indirect[Unmarshaler](v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberNS{Value: ns}) } @@ -634,7 +659,7 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val for k, av := range avMap { key := reflect.New(keyType).Elem() // handle pointer keys - _, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true}) + _, indirectKey := indirect[Unmarshaler](key, indirectOptions{skipUnmarshaler: true}) if err := decodeMapKey(k, indirectKey, tag{}); err != nil { return &UnmarshalTypeError{ Value: fmt.Sprintf("map key %q", k), @@ -777,7 +802,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), indirectOptions{}) + u, elem := indirect[Unmarshaler](v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberSS{Value: ss}) } @@ -825,7 +850,7 @@ type indirectOptions struct { // // Based on the enoding/json type reflect value type indirection in Go Stdlib // https://golang.org/src/encoding/json/decode.go indirect func. -func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) { +func indirect[U any](v reflect.Value, opts indirectOptions) (U, reflect.Value) { // Issue #24153 indicates that it is generally not a guaranteed property // that you may round-trip a reflect.Value by calling Value.Addr().Elem() // and expect the value to still be settable for values derived from @@ -859,7 +884,8 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value continue } if e.Kind() != reflect.Ptr && e.IsValid() { - return nil, e + var u U + return u, e } } if v.Kind() != reflect.Ptr { @@ -880,7 +906,7 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value v.Set(reflect.New(v.Type().Elem())) } if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() { - if u, ok := v.Interface().(Unmarshaler); ok { + if u, ok := v.Interface().(U); ok { return u, reflect.Value{} } } @@ -893,7 +919,8 @@ func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value } } - return nil, v + var u U + return u, v } // A Number represents a Attributevalue number literal. diff --git a/feature/dynamodb/attributevalue/decode_test.go b/feature/dynamodb/attributevalue/decode_test.go index d8ae9490a85..9481ed5f7e8 100644 --- a/feature/dynamodb/attributevalue/decode_test.go +++ b/feature/dynamodb/attributevalue/decode_test.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "testing" "time" @@ -1173,3 +1174,97 @@ func TestUnmarshalMap_keyPtrTypes(t *testing.T) { } } + +type textUnmarshalerString string + +func (v *textUnmarshalerString) UnmarshalText(text []byte) error { + *v = textUnmarshalerString("[[" + string(text) + "]]") + return nil +} + +func TestUnmarshalTextString(t *testing.T) { + in := &types.AttributeValueMemberS{Value: "foo"} + + var actual textUnmarshalerString + err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) { + o.UseEncodingUnmarshalers = true + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if string(actual) != "[[foo]]" { + t.Errorf("expected [[foo]], got %s", actual) + } +} + +func TestUnmarshalTextStringDisabled(t *testing.T) { + in := &types.AttributeValueMemberS{Value: "foo"} + + var actual textUnmarshalerString + err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) { + o.UseEncodingUnmarshalers = false + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if string(actual) != "foo" { + t.Errorf("expected foo, got %s", actual) + } +} + +type textUnmarshalerStruct struct { + I, J string +} + +func (v *textUnmarshalerStruct) UnmarshalText(text []byte) error { + parts := strings.Split(string(text), ";") + v.I = parts[0] + v.J = parts[1] + return nil +} + +func TestUnmarshalTextStruct(t *testing.T) { + in := &types.AttributeValueMemberS{Value: "foo;bar"} + + var actual textUnmarshalerStruct + err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) { + o.UseEncodingUnmarshalers = true + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + expected := textUnmarshalerStruct{"foo", "bar"} + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } +} + +type binaryUnmarshaler struct { + I, J byte +} + +func (v *binaryUnmarshaler) UnmarshalBinary(b []byte) error { + v.I = b[0] + v.J = b[1] + return nil +} + +func TestUnmarshalBinary(t *testing.T) { + in := &types.AttributeValueMemberB{Value: []byte{1, 2}} + + var actual binaryUnmarshaler + err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) { + o.UseEncodingUnmarshalers = true + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + expected := binaryUnmarshaler{1, 2} + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } +} diff --git a/feature/dynamodb/attributevalue/encode.go b/feature/dynamodb/attributevalue/encode.go index f62000a68f0..005a23c3b01 100644 --- a/feature/dynamodb/attributevalue/encode.go +++ b/feature/dynamodb/attributevalue/encode.go @@ -354,8 +354,7 @@ func MarshalListWithOptions(in interface{}, optFns ...func(*EncoderOptions)) ([] return asList.Value, nil } -// EncoderOptions is a collection of options shared between marshaling -// and unmarshaling +// EncoderOptions is a collection of options used by the marshaler. type EncoderOptions struct { // Support other custom struct tag keys, such as `yaml`, `json`, or `toml`. // Note that values provided with a custom TagKey must also be supported @@ -380,6 +379,19 @@ type EncoderOptions struct { // // Default encoding is time.RFC3339Nano in a DynamoDB String (S) data type. EncodeTime func(time.Time) (types.AttributeValue, error) + + // When enabled, the encoder will use implementations of + // encoding.TextMarshaler and encoding.BinaryMarshaler when present on + // marshaled values. + // + // Implementations are checked in the following order: + // - [Marshaler] + // - encoding.TextMarshaler + // - encoding.BinaryMarshaler + // + // The results of a MarshalText call will convert to string (S), results + // from a MarshalBinary call will convert to binary (B). + UseEncodingMarshalers bool } // An Encoder provides marshaling Go value types to AttributeValues. @@ -438,7 +450,7 @@ func (e *Encoder) encode(v reflect.Value, fieldTag tag) (types.AttributeValue, e v = valueElem(v) if v.Kind() != reflect.Invalid { - if av, err := tryMarshaler(v); err != nil { + if av, err := e.tryMarshaler(v); err != nil { return nil, err } else if av != nil { return av, nil @@ -822,7 +834,7 @@ func isNullableZeroValue(v reflect.Value) bool { return false } -func tryMarshaler(v reflect.Value) (types.AttributeValue, error) { +func (e *Encoder) tryMarshaler(v reflect.Value) (types.AttributeValue, error) { if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { v = v.Addr() } @@ -831,9 +843,35 @@ func tryMarshaler(v reflect.Value) (types.AttributeValue, error) { return nil, nil } - if m, ok := v.Interface().(Marshaler); ok { + i := v.Interface() + if m, ok := i.(Marshaler); ok { return m.MarshalDynamoDBAttributeValue() } + if e.options.UseEncodingMarshalers { + return e.tryEncodingMarshaler(i) + } + + return nil, nil +} + +func (e *Encoder) tryEncodingMarshaler(v any) (types.AttributeValue, error) { + if m, ok := v.(encoding.TextMarshaler); ok { + s, err := m.MarshalText() + if err != nil { + return nil, err + } + + return &types.AttributeValueMemberS{Value: string(s)}, nil + } + + if m, ok := v.(encoding.BinaryMarshaler); ok { + b, err := m.MarshalBinary() + if err != nil { + return nil, err + } + + return &types.AttributeValueMemberB{Value: b}, nil + } return nil, nil } diff --git a/feature/dynamodb/attributevalue/encode_test.go b/feature/dynamodb/attributevalue/encode_test.go index 342550a6c03..97a385c6679 100644 --- a/feature/dynamodb/attributevalue/encode_test.go +++ b/feature/dynamodb/attributevalue/encode_test.go @@ -94,6 +94,95 @@ func (m customBoolStringMarshaler) MarshalDynamoDBAttributeValue() (types.Attrib return &types.AttributeValueMemberS{Value: string(m)}, nil } +type customTextMarshaler struct { + I, J int +} + +func (v customTextMarshaler) MarshalText() ([]byte, error) { + text := fmt.Sprintf("{I: %d, J: %d}", v.I, v.J) + return []byte(text), nil +} + +type customBinaryMarshaler struct { + I, J byte +} + +func (v customBinaryMarshaler) MarshalBinary() ([]byte, error) { + return []byte{v.I, v.J}, nil +} + +type customAVAndTextMarshaler struct { + I, J int +} + +func (v customAVAndTextMarshaler) MarshalDynamoDBAttributeValue() (types.AttributeValue, error) { + return &types.AttributeValueMemberNS{Value: []string{ + fmt.Sprintf("%d", v.I), + fmt.Sprintf("%d", v.J), + }}, nil +} + +func (v customAVAndTextMarshaler) MarshalText() ([]byte, error) { + return []byte("should never happen"), nil +} + +func TestEncodingMarshalers(t *testing.T) { + cases := []struct { + input any + expected types.AttributeValue + useMarshalers bool + }{ + { + input: customTextMarshaler{1, 2}, + expected: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "I": &types.AttributeValueMemberN{Value: "1"}, + "J": &types.AttributeValueMemberN{Value: "2"}, + }}, + useMarshalers: false, + }, + { + input: customTextMarshaler{1, 2}, + expected: &types.AttributeValueMemberS{Value: "{I: 1, J: 2}"}, + useMarshalers: true, + }, + { + input: customBinaryMarshaler{1, 2}, + expected: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "I": &types.AttributeValueMemberN{Value: "1"}, + "J": &types.AttributeValueMemberN{Value: "2"}, + }}, + useMarshalers: false, + }, + { + input: customBinaryMarshaler{1, 2}, + expected: &types.AttributeValueMemberB{Value: []byte{1, 2}}, + useMarshalers: true, + }, + { + input: customAVAndTextMarshaler{1, 2}, + expected: &types.AttributeValueMemberNS{Value: []string{"1", "2"}}, + useMarshalers: false, + }, + { + input: customAVAndTextMarshaler{1, 2}, + expected: &types.AttributeValueMemberNS{Value: []string{"1", "2"}}, + useMarshalers: true, + }, + } + + for _, testCase := range cases { + actual, err := MarshalWithOptions(testCase.input, func(o *EncoderOptions) { + o.UseEncodingMarshalers = testCase.useMarshalers + }) + if err != nil { + t.Errorf("got unexpected error %v for input %v", err, testCase.input) + } + if diff := cmpDiff(testCase.expected, actual); len(diff) != 0 { + t.Errorf("expected match but got: %s", diff) + } + } +} + func TestCustomStringMarshaler(t *testing.T) { cases := []struct { expected types.AttributeValue