From 83e9c2bff8815996694d964b0aa7df0576fa82ed Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Thu, 25 Jan 2024 09:52:34 -0500 Subject: [PATCH] Support auto conversion of byte strings to and from text encodings. These options improve interoperability with programs that use JSON to encode and decode objects to and from both struct types and empty interface values. Signed-off-by: Ben Luddy --- decode.go | 356 +++++++++++++++++++++++++++++--------- decode_test.go | 460 ++++++++++++++++++++++++++++++++++++++++++++++--- encode.go | 145 ++++++++++++---- encode_test.go | 201 +++++++++++++++++++++ json_test.go | 132 ++++++++++++++ 5 files changed, 1155 insertions(+), 139 deletions(-) create mode 100644 json_test.go diff --git a/decode.go b/decode.go index c3269986..48c8f6fc 100644 --- a/decode.go +++ b/decode.go @@ -5,7 +5,9 @@ package cbor import ( "encoding" + "encoding/base64" "encoding/binary" + "encoding/hex" "errors" "fmt" "io" @@ -440,6 +442,13 @@ const ( // ByteStringToStringAllowed permits decoding a CBOR byte string into a Go string. ByteStringToStringAllowed + // ByteStringToStringAllowedWithExpectedLaterEncoding permits decoding a CBOR byte string + // into a Go string. Also, if the byte string is enclosed (directly or indirectly) by one of + // the "expected later encoding" tags (numbers 21 through 23), the destination string will + // be populated by applying the designated text encoding to the contents of the input byte + // string. + ByteStringToStringAllowedWithExpectedLaterEncoding + maxByteStringToStringMode ) @@ -593,6 +602,34 @@ func (bttm ByteStringToTimeMode) valid() bool { return bttm >= 0 && bttm < maxByteStringToTimeMode } +// ByteSliceExpectedEncodingMode specifies how to decode a byte string NOT enclosed in an "expected +// later encoding" tag (RFC 8949 Section 3.4.5.2) into a Go byte slice. 
+type ByteSliceExpectedEncodingMode int + +const ( + // ByteSliceExpectedEncodingIgnored copies the contents of the byte string, unmodified, into + // a destination Go byte slice. + ByteSliceExpectedEncodingIgnored = iota + + // ByteSliceExpectedEncodingBase64URL assumes that byte strings with no text encoding hint + // contain base64url-encoded bytes. + ByteSliceExpectedEncodingBase64URL + + // ByteSliceExpectedEncodingBase64 assumes that byte strings with no text encoding hint + // contain base64-encoded bytes. + ByteSliceExpectedEncodingBase64 + + // ByteSliceExpectedEncodingBase16 assumes that byte strings with no text encoding hint + // contain base16-encoded bytes. + ByteSliceExpectedEncodingBase16 + + maxByteSliceExpectedEncodingMode +) + +func (bseem ByteSliceExpectedEncodingMode) valid() bool { + return bseem >= 0 && bseem < maxByteSliceExpectedEncodingMode +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -710,6 +747,10 @@ type DecOptions struct { // ByteStringToTimeMode specifies the behavior when decoding a CBOR byte string into a Go time.Time. ByteStringToTime ByteStringToTimeMode + + // ByteSliceExpectedEncodingMode specifies how to decode a byte string NOT enclosed in an + // "expected later encoding" tag (RFC 8949 Section 3.4.5.2) into a Go byte slice. + ByteSliceExpectedEncoding ByteSliceExpectedEncodingMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -717,15 +758,35 @@ func (opts DecOptions) DecMode() (DecMode, error) { return opts.decMode() } -// DecModeWithTags returns DecMode with options and tags that are both immutable (safe for concurrency). -func (opts DecOptions) DecModeWithTags(tags TagSet) (DecMode, error) { +// validForTags checks that the provided tag set is compatible with these options and returns a +// non-nil error if and only if the provided tag set is incompatible. 
+func (opts DecOptions) validForTags(tags TagSet) error { if opts.TagsMd == TagsForbidden { - return nil, errors.New("cbor: cannot create DecMode with TagSet when TagsMd is TagsForbidden") + return errors.New("cbor: cannot create DecMode with TagSet when TagsMd is TagsForbidden") } if tags == nil { - return nil, errors.New("cbor: cannot create DecMode with nil value as TagSet") + return errors.New("cbor: cannot create DecMode with nil value as TagSet") + } + if opts.ByteStringToString == ByteStringToStringAllowedWithExpectedLaterEncoding || opts.ByteSliceExpectedEncoding != ByteSliceExpectedEncodingIgnored { + for _, tagNum := range []uint64{ + expectedLaterEncodingBase64URLTagNum, + expectedLaterEncodingBase64TagNum, + expectedLaterEncodingBase16TagNum, + } { + if rt := tags.getTypeFromTagNum([]uint64{tagNum}); rt != nil { + return fmt.Errorf("cbor: DecMode with non-default StringExpectedEncoding or ByteSliceExpectedEncoding treats tag %d as built-in and conflicts with the provided TagSet's registration of %v", tagNum, rt) + } + } + } + return nil +} +// DecModeWithTags returns DecMode with options and tags that are both immutable (safe for concurrency). +func (opts DecOptions) DecModeWithTags(tags TagSet) (DecMode, error) { + if err := opts.validForTags(tags); err != nil { + return nil, err + } dm, err := opts.decMode() if err != nil { return nil, err @@ -751,11 +812,8 @@ func (opts DecOptions) DecModeWithTags(tags TagSet) (DecMode, error) { // DecModeWithSharedTags returns DecMode with immutable options and mutable shared tags (safe for concurrency). 
func (opts DecOptions) DecModeWithSharedTags(tags TagSet) (DecMode, error) { - if opts.TagsMd == TagsForbidden { - return nil, errors.New("cbor: cannot create DecMode with TagSet when TagsMd is TagsForbidden") - } - if tags == nil { - return nil, errors.New("cbor: cannot create DecMode with nil value as TagSet") + if err := opts.validForTags(tags); err != nil { + return nil, err } dm, err := opts.decMode() if err != nil { @@ -892,30 +950,35 @@ func (opts DecOptions) decMode() (*decMode, error) { return nil, errors.New("cbor: invalid ByteStringToTime " + strconv.Itoa(int(opts.ByteStringToTime))) } + if !opts.ByteSliceExpectedEncoding.valid() { + return nil, errors.New("cbor: invalid ByteSliceExpectedEncoding " + strconv.Itoa(int(opts.ByteSliceExpectedEncoding))) + } + dm := decMode{ - dupMapKey: opts.DupMapKey, - timeTag: opts.TimeTag, - maxNestedLevels: opts.MaxNestedLevels, - maxArrayElements: opts.MaxArrayElements, - maxMapPairs: opts.MaxMapPairs, - indefLength: opts.IndefLength, - tagsMd: opts.TagsMd, - intDec: opts.IntDec, - mapKeyByteString: opts.MapKeyByteString, - extraReturnErrors: opts.ExtraReturnErrors, - defaultMapType: opts.DefaultMapType, - utf8: opts.UTF8, - fieldNameMatching: opts.FieldNameMatching, - bigIntDec: opts.BigIntDec, - defaultByteStringType: opts.DefaultByteStringType, - byteStringToString: opts.ByteStringToString, - fieldNameByteString: opts.FieldNameByteString, - unrecognizedTagToAny: opts.UnrecognizedTagToAny, - timeTagToAny: opts.TimeTagToAny, - simpleValues: simpleValues, - nanDec: opts.NaN, - infDec: opts.Inf, - byteStringToTime: opts.ByteStringToTime, + dupMapKey: opts.DupMapKey, + timeTag: opts.TimeTag, + maxNestedLevels: opts.MaxNestedLevels, + maxArrayElements: opts.MaxArrayElements, + maxMapPairs: opts.MaxMapPairs, + indefLength: opts.IndefLength, + tagsMd: opts.TagsMd, + intDec: opts.IntDec, + mapKeyByteString: opts.MapKeyByteString, + extraReturnErrors: opts.ExtraReturnErrors, + defaultMapType: opts.DefaultMapType, + utf8: 
opts.UTF8, + fieldNameMatching: opts.FieldNameMatching, + bigIntDec: opts.BigIntDec, + defaultByteStringType: opts.DefaultByteStringType, + byteStringToString: opts.ByteStringToString, + fieldNameByteString: opts.FieldNameByteString, + unrecognizedTagToAny: opts.UnrecognizedTagToAny, + timeTagToAny: opts.TimeTagToAny, + simpleValues: simpleValues, + nanDec: opts.NaN, + infDec: opts.Inf, + byteStringToTime: opts.ByteStringToTime, + byteSliceExpectedEncoding: opts.ByteSliceExpectedEncoding, } return &dm, nil @@ -968,30 +1031,31 @@ type DecMode interface { } type decMode struct { - tags tagProvider - dupMapKey DupMapKeyMode - timeTag DecTagMode - maxNestedLevels int - maxArrayElements int - maxMapPairs int - indefLength IndefLengthMode - tagsMd TagsMode - intDec IntDecMode - mapKeyByteString MapKeyByteStringMode - extraReturnErrors ExtraDecErrorCond - defaultMapType reflect.Type - utf8 UTF8Mode - fieldNameMatching FieldNameMatchingMode - bigIntDec BigIntDecMode - defaultByteStringType reflect.Type - byteStringToString ByteStringToStringMode - fieldNameByteString FieldNameByteStringMode - unrecognizedTagToAny UnrecognizedTagToAnyMode - timeTagToAny TimeTagToAnyMode - simpleValues *SimpleValueRegistry - nanDec NaNMode - infDec InfMode - byteStringToTime ByteStringToTimeMode + tags tagProvider + dupMapKey DupMapKeyMode + timeTag DecTagMode + maxNestedLevels int + maxArrayElements int + maxMapPairs int + indefLength IndefLengthMode + tagsMd TagsMode + intDec IntDecMode + mapKeyByteString MapKeyByteStringMode + extraReturnErrors ExtraDecErrorCond + defaultMapType reflect.Type + utf8 UTF8Mode + fieldNameMatching FieldNameMatchingMode + bigIntDec BigIntDecMode + defaultByteStringType reflect.Type + byteStringToString ByteStringToStringMode + fieldNameByteString FieldNameByteStringMode + unrecognizedTagToAny UnrecognizedTagToAnyMode + timeTagToAny TimeTagToAnyMode + simpleValues *SimpleValueRegistry + nanDec NaNMode + infDec InfMode + byteStringToTime ByteStringToTimeMode + 
byteSliceExpectedEncoding ByteSliceExpectedEncodingMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -1006,29 +1070,30 @@ func (dm *decMode) DecOptions() DecOptions { } return DecOptions{ - DupMapKey: dm.dupMapKey, - TimeTag: dm.timeTag, - MaxNestedLevels: dm.maxNestedLevels, - MaxArrayElements: dm.maxArrayElements, - MaxMapPairs: dm.maxMapPairs, - IndefLength: dm.indefLength, - TagsMd: dm.tagsMd, - IntDec: dm.intDec, - MapKeyByteString: dm.mapKeyByteString, - ExtraReturnErrors: dm.extraReturnErrors, - DefaultMapType: dm.defaultMapType, - UTF8: dm.utf8, - FieldNameMatching: dm.fieldNameMatching, - BigIntDec: dm.bigIntDec, - DefaultByteStringType: dm.defaultByteStringType, - ByteStringToString: dm.byteStringToString, - FieldNameByteString: dm.fieldNameByteString, - UnrecognizedTagToAny: dm.unrecognizedTagToAny, - TimeTagToAny: dm.timeTagToAny, - SimpleValues: simpleValues, - NaN: dm.nanDec, - Inf: dm.infDec, - ByteStringToTime: dm.byteStringToTime, + DupMapKey: dm.dupMapKey, + TimeTag: dm.timeTag, + MaxNestedLevels: dm.maxNestedLevels, + MaxArrayElements: dm.maxArrayElements, + MaxMapPairs: dm.maxMapPairs, + IndefLength: dm.indefLength, + TagsMd: dm.tagsMd, + IntDec: dm.intDec, + MapKeyByteString: dm.mapKeyByteString, + ExtraReturnErrors: dm.extraReturnErrors, + DefaultMapType: dm.defaultMapType, + UTF8: dm.utf8, + FieldNameMatching: dm.fieldNameMatching, + BigIntDec: dm.bigIntDec, + DefaultByteStringType: dm.defaultByteStringType, + ByteStringToString: dm.byteStringToString, + FieldNameByteString: dm.fieldNameByteString, + UnrecognizedTagToAny: dm.unrecognizedTagToAny, + TimeTagToAny: dm.timeTagToAny, + SimpleValues: simpleValues, + NaN: dm.nanDec, + Inf: dm.infDec, + ByteStringToTime: dm.byteStringToTime, + ByteSliceExpectedEncoding: dm.byteSliceExpectedEncoding, } } @@ -1116,6 +1181,17 @@ type decoder struct { data []byte off int // next read offset in data dm *decMode + + // expectedLaterEncodingTags stores a stack of encountered "Expected Later Encoding" 
tags, + // if any. + // + // The "Expected Later Encoding" tags (21 to 23) are valid for any data item. When decoding + // byte strings, the effective encoding comes from the tag nearest to the byte string being + // decoded. For example, the effective encoding of the byte string 21(22(h'41')) would be + // controlled by tag 22, and in the data item 23([h'42', 22([21(h'43')])]) the effective + // encoding of the byte strings h'42' and h'43' would be controlled by tags 23 and 21, + // respectively. + expectedLaterEncodingTags []uint64 } // value decodes CBOR data item into the value pointed to by v. @@ -1174,7 +1250,10 @@ func (t cborType) String() string { } const ( - selfDescribedCBORTagNum = 55799 + selfDescribedCBORTagNum = 55799 + expectedLaterEncodingBase64URLTagNum = 21 + expectedLaterEncodingBase64TagNum = 22 + expectedLaterEncodingBase16TagNum = 23 ) // parseToValue decodes CBOR data to value. It assumes data is well-formed, @@ -1329,6 +1408,11 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin case cborTypeByteString: b, copied := d.parseByteString() + b, converted, err := d.applyByteStringTextConversion(b, v.Type()) + if err != nil { + return err + } + copied = copied || converted return fillByteString(t, b, !copied, v, d.dm.byteStringToString) case cborTypeTextString: @@ -1413,6 +1497,15 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin GoType: tInfo.nonPtrType.String(), errorMsg: bi.String() + " overflows " + v.Type().String(), } + case expectedLaterEncodingBase64URLTagNum, expectedLaterEncodingBase64TagNum, expectedLaterEncodingBase16TagNum: + // If conversion for interoperability with text encodings is not configured, + // treat tags 21-23 as unregistered tags. 
+ if d.dm.byteStringToString == ByteStringToStringAllowedWithExpectedLaterEncoding || d.dm.byteSliceExpectedEncoding != ByteSliceExpectedEncodingIgnored { + d.expectedLaterEncodingTags = append(d.expectedLaterEncodingTags, tagNum) + defer func() { + d.expectedLaterEncodingTags = d.expectedLaterEncodingTags[:len(d.expectedLaterEncodingTags)-1] + }() + } } return d.parseToValue(v, tInfo) @@ -1679,9 +1772,19 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli return nValue, nil case cborTypeByteString: - switch d.dm.defaultByteStringType { - case nil, typeByteSlice: - b, copied := d.parseByteString() + b, copied := d.parseByteString() + var effectiveByteStringType = d.dm.defaultByteStringType + if effectiveByteStringType == nil { + effectiveByteStringType = typeByteSlice + } + b, converted, err := d.applyByteStringTextConversion(b, effectiveByteStringType) + if err != nil { + return nil, err + } + copied = copied || converted + + switch effectiveByteStringType { + case typeByteSlice: if copied { return b, nil } @@ -1689,10 +1792,8 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli copy(clone, b) return clone, nil case typeString: - b, _ := d.parseByteString() return string(b), nil default: - b, copied := d.parseByteString() if copied || d.dm.defaultByteStringType.Kind() == reflect.String { // Avoid an unnecessary copy since the conversion to string must // copy the underlying bytes. @@ -1754,6 +1855,16 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli return bi, nil } return *bi, nil + case expectedLaterEncodingBase64URLTagNum, expectedLaterEncodingBase64TagNum, expectedLaterEncodingBase16TagNum: + // If conversion for interoperability with text encodings is not configured, + // treat tags 21-23 as unregistered tags. 
+ if d.dm.byteStringToString == ByteStringToStringAllowedWithExpectedLaterEncoding || d.dm.byteSliceExpectedEncoding != ByteSliceExpectedEncodingIgnored { + d.expectedLaterEncodingTags = append(d.expectedLaterEncodingTags, tagNum) + defer func() { + d.expectedLaterEncodingTags = d.expectedLaterEncodingTags[:len(d.expectedLaterEncodingTags)-1] + }() + return d.parse(false) + } } if d.dm.tags != nil { @@ -1847,6 +1958,71 @@ func (d *decoder) parseByteString() ([]byte, bool) { return b, true } +// applyByteStringTextConversion converts bytes read from a byte string to or from a configured text +// encoding. If no transformation was performed (because it was not required), the original byte +// slice is returned and the bool return value is false. Otherwise, a new slice containing the +// converted bytes is returned along with the bool value true. +func (d *decoder) applyByteStringTextConversion(src []byte, dstType reflect.Type) ([]byte, bool, error) { + switch dstType.Kind() { + case reflect.String: + if d.dm.byteStringToString != ByteStringToStringAllowedWithExpectedLaterEncoding || len(d.expectedLaterEncodingTags) == 0 { + return src, false, nil + } + + switch d.expectedLaterEncodingTags[len(d.expectedLaterEncodingTags)-1] { + case expectedLaterEncodingBase64URLTagNum: + encoded := make([]byte, base64.RawURLEncoding.EncodedLen(len(src))) + base64.RawURLEncoding.Encode(encoded, src) + return encoded, true, nil + case expectedLaterEncodingBase64TagNum: + encoded := make([]byte, base64.StdEncoding.EncodedLen(len(src))) + base64.StdEncoding.Encode(encoded, src) + return encoded, true, nil + case expectedLaterEncodingBase16TagNum: + encoded := make([]byte, hex.EncodedLen(len(src))) + hex.Encode(encoded, src) + return encoded, true, nil + default: + // If this happens, there is a bug: the decoder has pushed an invalid + // "expected later encoding" tag to the stack. 
+ panic(fmt.Sprintf("unrecognized expected later encoding tag: %d", d.expectedLaterEncodingTags)) + } + case reflect.Slice: + if dstType.Elem().Kind() != reflect.Uint8 || len(d.expectedLaterEncodingTags) > 0 { + // Either the destination is not a slice of bytes, or the encoder that + // produced the input indicated an expected text encoding tag and therefore + // the content of the byte string has NOT been text encoded. + return src, false, nil + } + + switch d.dm.byteSliceExpectedEncoding { + case ByteSliceExpectedEncodingBase64URL: + decoded := make([]byte, base64.RawURLEncoding.DecodedLen(len(src))) + n, err := base64.RawURLEncoding.Decode(decoded, src) + if err != nil { + return nil, false, fmt.Errorf("cbor: failed to decode base64url string: %v", err) + } + return decoded[:n], true, nil + case ByteSliceExpectedEncodingBase64: + decoded := make([]byte, base64.StdEncoding.DecodedLen(len(src))) + n, err := base64.StdEncoding.Decode(decoded, src) + if err != nil { + return nil, false, fmt.Errorf("cbor: failed to decode base64 string: %v", err) + } + return decoded[:n], true, nil + case ByteSliceExpectedEncodingBase16: + decoded := make([]byte, hex.DecodedLen(len(src))) + n, err := hex.Decode(decoded, src) + if err != nil { + return nil, false, fmt.Errorf("cbor: failed to decode hex string: %v", err) + } + return decoded[:n], true, nil + } + } + + return src, false, nil +} + // parseTextString parses CBOR encoded text string. It returns a byte slice // to prevent creating an extra copy of string. Caller should wrap returned // byte slice as string when needed. 
@@ -2588,6 +2764,7 @@ func (d *decoder) foundBreak() bool { func (d *decoder) reset(data []byte) { d.data = data d.off = 0 + d.expectedLaterEncodingTags = d.expectedLaterEncodingTags[:0] } func (d *decoder) nextCBORType() cborType { @@ -2721,7 +2898,7 @@ func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts B } return errors.New("cbor: cannot set new value for " + v.Type().String()) } - if bsts == ByteStringToStringAllowed && v.Kind() == reflect.String { + if bsts != ByteStringToStringForbidden && v.Kind() == reflect.String { v.SetString(string(val)) return nil } @@ -2831,6 +3008,13 @@ func validBuiltinTag(tagNum uint64, contentHead byte) error { return errors.New("cbor: tag number 2 or 3 must be followed by byte string, got " + t.String()) } return nil + case expectedLaterEncodingBase64URLTagNum, expectedLaterEncodingBase64TagNum, expectedLaterEncodingBase16TagNum: + // From RFC 8949 3.4.5.2: + // The data item tagged can be a byte string or any other data item. In the latter + // case, the tag applies to all of the byte string data items contained in the data + // item, except for those contained in a nested data item tagged with an expected + // conversion. 
+ return nil } return nil } diff --git a/decode_test.go b/decode_test.go index 54f4b074..7a9c9018 100644 --- a/decode_test.go +++ b/decode_test.go @@ -4899,29 +4899,30 @@ func TestDecOptions(t *testing.T) { } opts1 := DecOptions{ - DupMapKey: DupMapKeyEnforcedAPF, - TimeTag: DecTagRequired, - MaxNestedLevels: 100, - MaxArrayElements: 102, - MaxMapPairs: 101, - IndefLength: IndefLengthForbidden, - TagsMd: TagsForbidden, - IntDec: IntDecConvertSigned, - MapKeyByteString: MapKeyByteStringForbidden, - ExtraReturnErrors: ExtraDecErrorUnknownField, - DefaultMapType: reflect.TypeOf(map[string]interface{}(nil)), - UTF8: UTF8DecodeInvalid, - FieldNameMatching: FieldNameMatchingCaseSensitive, - BigIntDec: BigIntDecodePointer, - DefaultByteStringType: reflect.TypeOf(""), - ByteStringToString: ByteStringToStringAllowed, - FieldNameByteString: FieldNameByteStringAllowed, - UnrecognizedTagToAny: UnrecognizedTagContentToAny, - TimeTagToAny: TimeTagToRFC3339, - SimpleValues: simpleValues, - NaN: NaNDecodeForbidden, - Inf: InfDecodeForbidden, - ByteStringToTime: ByteStringToTimeAllowed, + DupMapKey: DupMapKeyEnforcedAPF, + TimeTag: DecTagRequired, + MaxNestedLevels: 100, + MaxArrayElements: 102, + MaxMapPairs: 101, + IndefLength: IndefLengthForbidden, + TagsMd: TagsForbidden, + IntDec: IntDecConvertSigned, + MapKeyByteString: MapKeyByteStringForbidden, + ExtraReturnErrors: ExtraDecErrorUnknownField, + DefaultMapType: reflect.TypeOf(map[string]interface{}(nil)), + UTF8: UTF8DecodeInvalid, + FieldNameMatching: FieldNameMatchingCaseSensitive, + BigIntDec: BigIntDecodePointer, + DefaultByteStringType: reflect.TypeOf(""), + ByteStringToString: ByteStringToStringAllowed, + FieldNameByteString: FieldNameByteStringAllowed, + UnrecognizedTagToAny: UnrecognizedTagContentToAny, + TimeTagToAny: TimeTagToRFC3339, + SimpleValues: simpleValues, + NaN: NaNDecodeForbidden, + Inf: InfDecodeForbidden, + ByteStringToTime: ByteStringToTimeAllowed, + ByteSliceExpectedEncoding: 
ByteSliceToByteStringWithExpectedConversionToBase64, } ov := reflect.ValueOf(opts1) for i := 0; i < ov.NumField(); i++ { @@ -9512,3 +9513,416 @@ func TestDecModeByteStringToTime(t *testing.T) { }) } } + +func TestInvalidByteSliceExpectedEncodingMode(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{ByteSliceExpectedEncoding: -1}, + wantErrorMsg: "cbor: invalid ByteSliceExpectedEncoding -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{ByteSliceExpectedEncoding: 101}, + wantErrorMsg: "cbor: invalid ByteSliceExpectedEncoding 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestDecOptionsConflictWithRegisteredTags(t *testing.T) { + type empty struct{} + + for _, tc := range []struct { + name string + opts DecOptions + tags func(TagSet) error + wantErr string + }{ + { + name: "base64url encoding tag ignored by default", + opts: DecOptions{}, + tags: func(tags TagSet) error { + return tags.Add(TagOptions{DecTag: DecTagOptional}, reflect.TypeOf(empty{}), 21) + }, + wantErr: "", + }, + { + name: "base64url encoding tag conflicts in ByteStringToStringAllowedWithExpectedLaterEncoding mode", + opts: DecOptions{ByteStringToString: ByteStringToStringAllowedWithExpectedLaterEncoding}, + tags: func(tags TagSet) error { + return tags.Add(TagOptions{DecTag: DecTagOptional}, reflect.TypeOf(empty{}), 21) + }, + wantErr: "cbor: DecMode with non-default StringExpectedEncoding or ByteSliceExpectedEncoding treats tag 21 as built-in and conflicts with the provided TagSet's registration of cbor.empty", + }, + { + name: "base64url encoding tag conflicts with non-default ByteSliceExpectedEncoding option", + 
opts: DecOptions{ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase16}, + tags: func(tags TagSet) error { + return tags.Add(TagOptions{DecTag: DecTagOptional}, reflect.TypeOf(empty{}), 21) + }, + wantErr: "cbor: DecMode with non-default StringExpectedEncoding or ByteSliceExpectedEncoding treats tag 21 as built-in and conflicts with the provided TagSet's registration of cbor.empty", + }, + { + name: "base64 encoding tag ignored by default", + opts: DecOptions{}, + tags: func(tags TagSet) error { + return tags.Add(TagOptions{DecTag: DecTagOptional}, reflect.TypeOf(empty{}), 22) + }, + wantErr: "", + }, + { + name: "base64 encoding tag conflicts in ByteStringToStringAllowedWithExpectedLaterEncoding mode", + opts: DecOptions{ByteStringToString: ByteStringToStringAllowedWithExpectedLaterEncoding}, + tags: func(tags TagSet) error { + return tags.Add(TagOptions{DecTag: DecTagOptional}, reflect.TypeOf(empty{}), 22) + }, + wantErr: "cbor: DecMode with non-default StringExpectedEncoding or ByteSliceExpectedEncoding treats tag 22 as built-in and conflicts with the provided TagSet's registration of cbor.empty", + }, + { + name: "base64 encoding tag conflicts with non-default ByteSliceExpectedEncoding option", + opts: DecOptions{ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase16}, + tags: func(tags TagSet) error { + return tags.Add(TagOptions{DecTag: DecTagOptional}, reflect.TypeOf(empty{}), 22) + }, + wantErr: "cbor: DecMode with non-default StringExpectedEncoding or ByteSliceExpectedEncoding treats tag 22 as built-in and conflicts with the provided TagSet's registration of cbor.empty", + }, + { + name: "base16 encoding tag ignored by default", + opts: DecOptions{}, + tags: func(tags TagSet) error { + return tags.Add(TagOptions{DecTag: DecTagOptional}, reflect.TypeOf(empty{}), 23) + }, + wantErr: "", + }, + { + name: "base16 encoding tag conflicts in ByteStringToStringAllowedWithExpectedLaterEncoding mode", + opts: DecOptions{ByteStringToString: 
ByteStringToStringAllowedWithExpectedLaterEncoding}, + tags: func(tags TagSet) error { + return tags.Add(TagOptions{DecTag: DecTagOptional}, reflect.TypeOf(empty{}), 23) + }, + wantErr: "cbor: DecMode with non-default StringExpectedEncoding or ByteSliceExpectedEncoding treats tag 23 as built-in and conflicts with the provided TagSet's registration of cbor.empty", + }, + { + name: "base16 encoding tag conflicts with non-default ByteSliceExpectedEncoding option", + opts: DecOptions{ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase16}, + tags: func(tags TagSet) error { + return tags.Add(TagOptions{DecTag: DecTagOptional}, reflect.TypeOf(empty{}), 23) + }, + wantErr: "cbor: DecMode with non-default StringExpectedEncoding or ByteSliceExpectedEncoding treats tag 23 as built-in and conflicts with the provided TagSet's registration of cbor.empty", + }, + } { + t.Run(tc.name, func(t *testing.T) { + tags := NewTagSet() + if err := tc.tags(tags); err != nil { + t.Fatal(err) + } + + if _, err := tc.opts.DecModeWithTags(tags); err == nil { + if tc.wantErr != "" { + t.Errorf("got nil error from DecModeWithTags, want %q", tc.wantErr) + } + } else if got := err.Error(); got != tc.wantErr { + if tc.wantErr != "" { + t.Errorf("unexpected error from DecModeWithTags, got %q want %q", got, tc.wantErr) + } else { + t.Errorf("want nil error from DecModeWithTags, got %q", got) + } + } + + if _, err := tc.opts.DecModeWithSharedTags(tags); err == nil { + if tc.wantErr != "" { + t.Errorf("got nil error from DecModeWithSharedTags, want %q", tc.wantErr) + } + } else if got := err.Error(); got != tc.wantErr { + if tc.wantErr != "" { + t.Errorf("unexpected error from DecModeWithSharedTags, got %q want %q", got, tc.wantErr) + } else { + t.Errorf("want nil error from DecModeWithSharedTags, got %q", got) + } + } + }) + } +} + +func TestUnmarshalByteStringTextConversionError(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + dstType reflect.Type + in []byte + 
wantErr string + }{ + { + name: "reject untagged byte string containing invalid base64url", + opts: DecOptions{ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase64URL}, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0x41, 0x00}, + wantErr: "cbor: failed to decode base64url string: illegal base64 data at input byte 0", + }, + { + name: "reject untagged byte string containing invalid base64url", + opts: DecOptions{ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase64}, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0x41, 0x00}, + wantErr: "cbor: failed to decode base64 string: illegal base64 data at input byte 0", + }, + { + name: "reject untagged byte string containing invalid base16", + opts: DecOptions{ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase16}, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0x41, 0x00}, + wantErr: "cbor: failed to decode hex string: encoding/hex: invalid byte: U+0000", + }, + { + name: "accept tagged byte string containing invalid base64url", + opts: DecOptions{ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase64URL}, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0xd5, 0x41, 0x00}, + wantErr: "", + }, + { + name: "accept tagged byte string containing invalid base64url", + opts: DecOptions{ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase64}, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0xd5, 0x41, 0x00}, + wantErr: "", + }, + { + name: "accept tagged byte string containing invalid base16", + opts: DecOptions{ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase16}, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0xd5, 0x41, 0x00}, + wantErr: "", + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecMode() + if err != nil { + t.Fatal(err) + } + + if err := dm.Unmarshal(tc.in, reflect.New(tc.dstType).Interface()); err == nil { + if tc.wantErr != "" { + t.Errorf("got nil error, want %q", tc.wantErr) + } + } else if got := err.Error(); got != tc.wantErr { + if 
tc.wantErr == "" { + t.Errorf("expected nil error, got %q", got) + } else { + t.Errorf("unexpected error, got %q want %q", got, tc.wantErr) + } + } + }) + } +} + +func TestUnmarshalByteStringTextConversion(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + dstType reflect.Type + in []byte + want interface{} + }{ + { + name: "untagged into string", + opts: DecOptions{ + ByteStringToString: ByteStringToStringAllowedWithExpectedLaterEncoding, + }, + dstType: reflect.TypeOf(""), + in: []byte{0x41, 0xff}, // h'ff' + want: "\xff", + }, + { + name: "tagged base64url into string", + opts: DecOptions{ + ByteStringToString: ByteStringToStringAllowedWithExpectedLaterEncoding, + }, + dstType: reflect.TypeOf(""), + in: []byte{0xd5, 0x41, 0xff}, // 21(h'ff') + want: "_w", + }, + { + name: "indirectly tagged base64url into string", + opts: DecOptions{ + ByteStringToString: ByteStringToStringAllowedWithExpectedLaterEncoding, + }, + dstType: reflect.TypeOf(""), + in: []byte{0xd5, 0xd9, 0xd9, 0xf7, 0x41, 0xff}, // 21(55799(h'ff')) + want: "_w", + }, + { + name: "tagged base64url into string tags ignored", + opts: DecOptions{ + ByteStringToString: ByteStringToStringAllowed, + }, + dstType: reflect.TypeOf(""), + in: []byte{0xd5, 0x41, 0xff}, // 21(h'ff') + want: "\xff", + }, + { + name: "tagged into []byte with default encoding base64url", + opts: DecOptions{ + ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase64URL, + }, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0xd5, 0x41, 0xff}, // 21(h'ff') + want: []byte{0xff}, + }, + { + name: "indirectly tagged into []byte with default encoding base64url", + opts: DecOptions{ + ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase64URL, + }, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0xd5, 0xd9, 0xd9, 0xf7, 0x41, 0xff}, // 21(55799(h'ff')) + want: []byte{0xff}, + }, + { + name: "untagged base64url into []byte with default encoding base64url", + opts: DecOptions{ + ByteSliceExpectedEncoding: 
ByteSliceExpectedEncodingBase64URL, + }, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0x42, '_', 'w'}, // '_w' + want: []byte{0xff}, + }, + { + name: "tagged base64 into string", + opts: DecOptions{ + ByteStringToString: ByteStringToStringAllowedWithExpectedLaterEncoding, + }, + dstType: reflect.TypeOf(""), + in: []byte{0xd6, 0x41, 0xff}, // 22(h'ff') + want: "/w==", + }, + { + name: "indirectly tagged base64 into string", + opts: DecOptions{ + ByteStringToString: ByteStringToStringAllowedWithExpectedLaterEncoding, + }, + dstType: reflect.TypeOf(""), + in: []byte{0xd6, 0xd9, 0xd9, 0xf7, 0x41, 0xff}, // 22(55799(h'ff')) + want: "/w==", + }, + { + name: "tagged base64 into string tags ignored", + opts: DecOptions{ + ByteStringToString: ByteStringToStringAllowed, + }, + dstType: reflect.TypeOf(""), + in: []byte{0xd6, 0x41, 0xff}, // 22(h'ff') + want: "\xff", + }, + { + name: "tagged into []byte with default encoding base64", + opts: DecOptions{ + ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase64, + }, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0xd6, 0x41, 0xff}, // 22(h'ff') + want: []byte{0xff}, + }, + { + name: "indirectly tagged into []byte with default encoding base64", + opts: DecOptions{ + ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase64, + }, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0xd6, 0xd9, 0xd9, 0xf7, 0x41, 0xff}, // 22(55799(h'ff')) + want: []byte{0xff}, + }, + { + name: "untagged base64 into []byte with default encoding base64", + opts: DecOptions{ + ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase64, + }, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0x44, '/', 'w', '=', '='}, // '/w==' + want: []byte{0xff}, + }, + { + name: "tagged base16 into string", + opts: DecOptions{ + ByteStringToString: ByteStringToStringAllowedWithExpectedLaterEncoding, + }, + dstType: reflect.TypeOf(""), + in: []byte{0xd7, 0x41, 0xff}, // 23(h'ff') + want: "ff", + }, + { + name: "indirectly tagged base16 into string", + opts: 
DecOptions{ + ByteStringToString: ByteStringToStringAllowedWithExpectedLaterEncoding, + }, + dstType: reflect.TypeOf(""), + in: []byte{0xd7, 0xd9, 0xd9, 0xf7, 0x41, 0xff}, // 23(55799(h'ff')) + want: "ff", + }, + { + name: "tagged base16 into string tags ignored", + opts: DecOptions{ + ByteStringToString: ByteStringToStringAllowed, + }, + dstType: reflect.TypeOf(""), + in: []byte{0xd7, 0x41, 0xff}, // 23(h'ff') + want: "\xff", + }, + { + name: "tagged into []byte with default encoding base16", + opts: DecOptions{ + ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase16, + }, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0xd7, 0x41, 0xff}, // 23(h'ff') + want: []byte{0xff}, + }, + { + name: "indirectly tagged into []byte with default encoding base16", + opts: DecOptions{ + ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase16, + }, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0xd7, 0xd9, 0xd9, 0xf7, 0x41, 0xff}, // 23(55799(h'ff')) + want: []byte{0xff}, + }, + { + name: "untagged base16 into []byte with default encoding base16", + opts: DecOptions{ + ByteSliceExpectedEncoding: ByteSliceExpectedEncodingBase16, + }, + dstType: reflect.TypeOf([]byte{}), + in: []byte{0x42, 'f', 'f'}, + want: []byte{0xff}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecMode() + if err != nil { + t.Fatal(err) + } + + dstVal := reflect.New(tc.dstType) + if err := dm.Unmarshal(tc.in, dstVal.Interface()); err != nil { + t.Fatal(err) + } + + if dst := dstVal.Elem().Interface(); !reflect.DeepEqual(dst, tc.want) { + t.Errorf("got: %#v, want %#v", dst, tc.want) + } + }) + } +} diff --git a/encode.go b/encode.go index a58c3a54..85a79306 100644 --- a/encode.go +++ b/encode.go @@ -387,6 +387,62 @@ func (fnm FieldNameMode) valid() bool { return fnm >= 0 && fnm < maxFieldNameMode } +// ByteSliceMode specifies how to encode slices of bytes. 
+type ByteSliceMode int + +const ( + // ByteSliceToByteString encodes slices of bytes to CBOR byte string (major type 2). + ByteSliceToByteString ByteSliceMode = iota + + // ByteSliceToByteStringWithExpectedConversionToBase64URL encodes slices of bytes to CBOR + // byte string (major type 2) inside tag 21 (expected conversion to base64url encoding, see + // RFC 8949 Section 3.4.5.2). + ByteSliceToByteStringWithExpectedConversionToBase64URL + + // ByteSliceToByteStringWithExpectedConversionToBase64 encodes slices of bytes to CBOR byte + // string (major type 2) inside tag 22 (expected conversion to base64 encoding, see RFC 8949 + // Section 3.4.5.2). + ByteSliceToByteStringWithExpectedConversionToBase64 + + // ByteSliceToByteStringWithExpectedConversionToBase16 encodes slices of bytes to CBOR byte + // string (major type 2) inside tag 23 (expected conversion to base16 encoding, see RFC 8949 + // Section 3.4.5.2). + ByteSliceToByteStringWithExpectedConversionToBase16 +) + +func (bsm ByteSliceMode) encodingTag() (uint64, error) { + switch bsm { + case ByteSliceToByteString: + return 0, nil + case ByteSliceToByteStringWithExpectedConversionToBase64URL: + return expectedLaterEncodingBase64URLTagNum, nil + case ByteSliceToByteStringWithExpectedConversionToBase64: + return expectedLaterEncodingBase64TagNum, nil + case ByteSliceToByteStringWithExpectedConversionToBase16: + return expectedLaterEncodingBase16TagNum, nil + } + return 0, errors.New("cbor: invalid ByteSlice " + strconv.Itoa(int(bsm))) +} + +// ByteArrayMode specifies how to encode byte arrays. +type ByteArrayMode int + +const ( + // ByteArrayToByteSlice encodes byte arrays the same way that a byte slice with identical + // length and contents is encoded. + ByteArrayToByteSlice ByteArrayMode = iota + + // ByteArrayToArray encodes byte arrays to the CBOR array type with one unsigned integer + // item for each byte in the array.
+ ByteArrayToArray + + maxByteArrayMode +) + +func (bam ByteArrayMode) valid() bool { + return bam >= 0 && bam < maxByteArrayMode +} + // EncOptions specifies encoding options. type EncOptions struct { // Sort specifies sorting order. @@ -431,6 +487,12 @@ type EncOptions struct { // FieldName specifies the CBOR type to use when encoding struct field names. FieldName FieldNameMode + + // ByteSlice specifies how to encode byte slices. + ByteSlice ByteSliceMode + + // ByteArray specifies how to encode byte arrays. + ByteArray ByteArrayMode } // CanonicalEncOptions returns EncOptions for "Canonical CBOR" encoding, @@ -616,21 +678,31 @@ func (opts EncOptions) encMode() (*encMode, error) { if !opts.FieldName.valid() { return nil, errors.New("cbor: invalid FieldName " + strconv.Itoa(int(opts.FieldName))) } + byteSliceEncodingTag, err := opts.ByteSlice.encodingTag() + if err != nil { + return nil, err + } + if !opts.ByteArray.valid() { + return nil, errors.New("cbor: invalid ByteArray " + strconv.Itoa(int(opts.ByteArray))) + } em := encMode{ - sort: opts.Sort, - shortestFloat: opts.ShortestFloat, - nanConvert: opts.NaNConvert, - infConvert: opts.InfConvert, - bigIntConvert: opts.BigIntConvert, - time: opts.Time, - timeTag: opts.TimeTag, - indefLength: opts.IndefLength, - nilContainers: opts.NilContainers, - tagsMd: opts.TagsMd, - omitEmpty: opts.OmitEmpty, - stringType: opts.String, - stringMajorType: stringMajorType, - fieldName: opts.FieldName, + sort: opts.Sort, + shortestFloat: opts.ShortestFloat, + nanConvert: opts.NaNConvert, + infConvert: opts.InfConvert, + bigIntConvert: opts.BigIntConvert, + time: opts.Time, + timeTag: opts.TimeTag, + indefLength: opts.IndefLength, + nilContainers: opts.NilContainers, + tagsMd: opts.TagsMd, + omitEmpty: opts.OmitEmpty, + stringType: opts.String, + stringMajorType: stringMajorType, + fieldName: opts.FieldName, + byteSlice: opts.ByteSlice, + byteSliceEncodingTag: byteSliceEncodingTag, + byteArray: opts.ByteArray, } return &em, nil }
@@ -643,21 +715,24 @@ type EncMode interface { } type encMode struct { - tags tagProvider - sort SortMode - shortestFloat ShortestFloatMode - nanConvert NaNConvertMode - infConvert InfConvertMode - bigIntConvert BigIntConvertMode - time TimeMode - timeTag EncTagMode - indefLength IndefLengthMode - nilContainers NilContainersMode - tagsMd TagsMode - omitEmpty OmitEmptyMode - stringType StringMode - stringMajorType cborType - fieldName FieldNameMode + tags tagProvider + sort SortMode + shortestFloat ShortestFloatMode + nanConvert NaNConvertMode + infConvert InfConvertMode + bigIntConvert BigIntConvertMode + time TimeMode + timeTag EncTagMode + indefLength IndefLengthMode + nilContainers NilContainersMode + tagsMd TagsMode + omitEmpty OmitEmptyMode + stringType StringMode + stringMajorType cborType + fieldName FieldNameMode + byteSlice ByteSliceMode + byteSliceEncodingTag uint64 + byteArray ByteArrayMode } var defaultEncMode, _ = EncOptions{}.encMode() @@ -747,6 +822,8 @@ func (em *encMode) EncOptions() EncOptions { OmitEmpty: em.omitEmpty, String: em.stringType, FieldName: em.fieldName, + ByteSlice: em.byteSlice, + ByteArray: em.byteArray, } } @@ -1026,6 +1103,9 @@ func encodeByteString(e *encoderBuffer, em *encMode, v reflect.Value) error { e.Write(cborNil) return nil } + if vk == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 && em.byteSliceEncodingTag != 0 { + encodeHead(e, byte(cborTypeTag), em.byteSliceEncodingTag) + } if b := em.encTagBytes(v.Type()); b != nil { e.Write(b) } @@ -1059,6 +1139,9 @@ type arrayEncodeFunc struct { } func (ae arrayEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error { + if em.byteArray == ByteArrayToByteSlice && v.Type().Elem().Kind() == reflect.Uint8 { + return encodeByteString(e, em, v) + } if v.Kind() == reflect.Slice && v.IsNil() && em.nilContainers == NilContainerAsNull { e.Write(cborNil) return nil @@ -1570,10 +1653,12 @@ func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) { return 
encodeFloat, isEmptyFloat case reflect.String: return encodeString, isEmptyString - case reflect.Slice, reflect.Array: + case reflect.Slice: if t.Elem().Kind() == reflect.Uint8 { return encodeByteString, isEmptySlice } + fallthrough + case reflect.Array: f, _ := getEncodeFunc(t.Elem()) if f == nil { return nil, nil diff --git a/encode_test.go b/encode_test.go index 25889736..bc7654ee 100644 --- a/encode_test.go +++ b/encode_test.go @@ -3683,6 +3683,8 @@ func TestEncOptions(t *testing.T) { OmitEmpty: OmitEmptyGoValue, String: StringToByteString, FieldName: FieldNameToByteString, + ByteSlice: ByteSliceToByteStringWithExpectedConversionToBase16, + ByteArray: ByteArrayToArray, } ov := reflect.ValueOf(opts1) for i := 0; i < ov.NumField(); i++ { @@ -4421,3 +4423,202 @@ func TestSortModeFastShuffle(t *testing.T) { }) } } + +func TestInvalidByteSlice(t *testing.T) { + for _, tc := range []struct { + name string + opts EncOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: EncOptions{ByteSlice: -1}, + wantErrorMsg: "cbor: invalid ByteSlice -1", + }, + { + name: "above range of valid modes", + opts: EncOptions{ByteSlice: 101}, + wantErrorMsg: "cbor: invalid ByteSlice 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.EncMode() + if err == nil { + t.Errorf("EncMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("EncMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestInvalidByteArray(t *testing.T) { + for _, tc := range []struct { + name string + opts EncOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: EncOptions{ByteArray: -1}, + wantErrorMsg: "cbor: invalid ByteArray -1", + }, + { + name: "above range of valid modes", + opts: EncOptions{ByteArray: 101}, + wantErrorMsg: "cbor: invalid ByteArray 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.EncMode() + if err == nil { + 
t.Errorf("EncMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("EncMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestMarshalByteArrayMode(t *testing.T) { + for _, tc := range []struct { + name string + opts EncOptions + in interface{} + expected []byte + }{ + { + name: "byte array treated as byte slice by default", + opts: EncOptions{}, + in: [1]byte{}, + expected: []byte{0x41, 0x00}, + }, + { + name: "byte array treated as byte slice with ByteArrayToByteSlice", + opts: EncOptions{ByteArray: ByteArrayToByteSlice}, + in: [1]byte{}, + expected: []byte{0x41, 0x00}, + }, + { + name: "byte array treated as array of integers with ByteArrayToArray", + opts: EncOptions{ByteArray: ByteArrayToArray}, + in: [1]byte{}, + expected: []byte{0x81, 0x00}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + em, err := tc.opts.EncMode() + if err != nil { + t.Fatal(err) + } + + out, err := em.Marshal(tc.in) + if err != nil { + t.Fatal(err) + } + + if string(out) != string(tc.expected) { + t.Errorf("unexpected output, got 0x%x want 0x%x", out, tc.expected) + } + }) + } +} + +func TestMarshalByteSliceMode(t *testing.T) { + type namedByteSlice []byte + ts := NewTagSet() + if err := ts.Add(TagOptions{EncTag: EncTagRequired}, reflect.TypeOf(namedByteSlice{}), 0xcc); err != nil { + t.Fatal(err) + } + + for _, tc := range []struct { + name string + tags TagSet + opts EncOptions + in interface{} + expected []byte + }{ + { + name: "byte slice marshals to byte string by default", + opts: EncOptions{}, + in: []byte{0xbb}, + expected: []byte{0x41, 0xbb}, + }, + { + name: "byte slice marshals to byte string with ByteSliceToByteString", + opts: EncOptions{ByteSlice: ByteSliceToByteString}, + in: []byte{0xbb}, + expected: []byte{0x41, 0xbb}, + }, + { + name: "byte slice marshaled to byte string enclosed in base64url expected encoding tag", + opts: EncOptions{ByteSlice: 
ByteSliceToByteStringWithExpectedConversionToBase64URL}, + in: []byte{0xbb}, + expected: []byte{0xd5, 0x41, 0xbb}, + }, + { + name: "byte slice marshaled to byte string enclosed in base64 expected encoding tag", + opts: EncOptions{ByteSlice: ByteSliceToByteStringWithExpectedConversionToBase64}, + in: []byte{0xbb}, + expected: []byte{0xd6, 0x41, 0xbb}, + }, + { + name: "byte slice marshaled to byte string enclosed in base16 expected encoding tag", + opts: EncOptions{ByteSlice: ByteSliceToByteStringWithExpectedConversionToBase16}, + in: []byte{0xbb}, + expected: []byte{0xd7, 0x41, 0xbb}, + }, + { + name: "user-registered tag numbers are encoded with no expected encoding tag", + tags: ts, + opts: EncOptions{ByteSlice: ByteSliceToByteString}, + in: namedByteSlice{0xbb}, + expected: []byte{0xd8, 0xcc, 0x41, 0xbb}, + }, + { + name: "user-registered tag numbers are encoded after base64url expected encoding tag", + tags: ts, + opts: EncOptions{ByteSlice: ByteSliceToByteStringWithExpectedConversionToBase64URL}, + in: namedByteSlice{0xbb}, + expected: []byte{0xd5, 0xd8, 0xcc, 0x41, 0xbb}, + }, + { + name: "user-registered tag numbers are encoded after base64 expected encoding tag", + tags: ts, + opts: EncOptions{ByteSlice: ByteSliceToByteStringWithExpectedConversionToBase64}, + in: namedByteSlice{0xbb}, + expected: []byte{0xd6, 0xd8, 0xcc, 0x41, 0xbb}, + }, + { + name: "user-registered tag numbers are encoded after base16 expected encoding tag", + tags: ts, + opts: EncOptions{ByteSlice: ByteSliceToByteStringWithExpectedConversionToBase16}, + in: namedByteSlice{0xbb}, + expected: []byte{0xd7, 0xd8, 0xcc, 0x41, 0xbb}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var em EncMode + if tc.tags != nil { + var err error + if em, err = tc.opts.EncModeWithTags(tc.tags); err != nil { + t.Fatal(err) + } + } else { + var err error + if em, err = tc.opts.EncMode(); err != nil { + t.Fatal(err) + } + } + + out, err := em.Marshal(tc.in) + if err != nil { + t.Fatal(err) + } + + if 
string(out) != string(tc.expected) { + t.Errorf("unexpected output, got 0x%x want 0x%x", out, tc.expected) + } + }) + } +} diff --git a/json_test.go b/json_test.go new file mode 100644 index 00000000..e1cc8736 --- /dev/null +++ b/json_test.go @@ -0,0 +1,132 @@ +package cbor_test + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/fxamacker/cbor/v2" +) + +// TestStdlibJSONCompatibility tests compatibility as a drop-in replacement for the standard library +// encoding/json package on a round trip encoding from Go object to interface{}. +func TestStdlibJSONCompatibility(t *testing.T) { + // TODO: With better coverage and compatibility, it could be useful to expose these option + // configurations to users. + + enc, err := cbor.EncOptions{ + ByteSlice: cbor.ByteSliceToByteStringWithExpectedConversionToBase64, + String: cbor.StringToByteString, + ByteArray: cbor.ByteArrayToArray, + }.EncMode() + if err != nil { + t.Fatal(err) + } + + dec, err := cbor.DecOptions{ + DefaultByteStringType: reflect.TypeOf(""), + ByteStringToString: cbor.ByteStringToStringAllowedWithExpectedLaterEncoding, + ByteSliceExpectedEncoding: cbor.ByteSliceExpectedEncodingBase64, + }.DecMode() + if err != nil { + t.Fatal(err) + } + + for _, tc := range []struct { + name string + original interface{} + ifaceEqual bool // require equal intermediate interface{} values from both protocols + }{ + { + name: "byte slice to base64-encoded string", + original: []byte("hello world"), + ifaceEqual: true, + }, + { + name: "byte array to array of integers", + original: [11]byte{'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'}, + ifaceEqual: false, // encoding/json decodes the array elements to float64 + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Logf("original: %#v", tc.original) + + j1, err := json.Marshal(tc.original) + if err != nil { + t.Fatal(err) + } + t.Logf("original to json: %s", string(j1)) + + c1, err := enc.Marshal(tc.original) + if err != nil { + t.Fatal(err) + } + 
diag1, err := cbor.Diagnose(c1) + if err != nil { + t.Fatal(err) + } + t.Logf("original to cbor: %s", diag1) + + var jintf interface{} + err = json.Unmarshal(j1, &jintf) + if err != nil { + t.Fatal(err) + } + t.Logf("json to interface{} (%T): %#v", jintf, jintf) + + var cintf interface{} + err = dec.Unmarshal(c1, &cintf) + if err != nil { + t.Fatal(err) + } + t.Logf("cbor to interface{} (%T): %#v", cintf, cintf) + + j2, err := json.Marshal(jintf) + if err != nil { + t.Fatal(err) + } + t.Logf("interface{} to json: %s", string(j2)) + + c2, err := enc.Marshal(cintf) + if err != nil { + t.Fatal(err) + } + diag2, err := cbor.Diagnose(c2) + if err != nil { + t.Fatal(err) + } + t.Logf("interface{} to cbor: %s", diag2) + + if !reflect.DeepEqual(jintf, cintf) { + if tc.ifaceEqual { + t.Errorf("native-to-interface{} via cbor differed from native-to-interface{} via json") + } else { + t.Logf("native-to-interface{} via cbor differed from native-to-interface{} via json") + } + } + + jfinalValue := reflect.New(reflect.TypeOf(tc.original)) + err = json.Unmarshal(j2, jfinalValue.Interface()) + if err != nil { + t.Fatal(err) + } + jfinal := jfinalValue.Elem().Interface() + t.Logf("json to native: %#v", jfinal) + if !reflect.DeepEqual(tc.original, jfinal) { + t.Error("diff in json roundtrip") + } + + cfinalValue := reflect.New(reflect.TypeOf(tc.original)) + err = dec.Unmarshal(c2, cfinalValue.Interface()) + if err != nil { + t.Fatal(err) + } + cfinal := cfinalValue.Elem().Interface() + t.Logf("cbor to native: %#v", cfinal) + if !reflect.DeepEqual(tc.original, cfinal) { + t.Error("diff in cbor roundtrip") + } + + }) + } +}