Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve byte string format decoding options #550

Merged
merged 2 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 161 additions & 118 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,38 @@ func (e UnacceptableDataItemError) Error() string {
return fmt.Sprintf("cbor: data item of cbor type %s is not accepted by protocol: %s", e.CBORType, e.Message)
}

// ByteStringExpectedFormatError is returned when unmarshaling CBOR byte string fails when
// using non-default ByteStringExpectedFormat decoding option that makes decoder expect
// a specified format such as base64, hex, etc.
type ByteStringExpectedFormatError struct {
expectedFormatOption ByteStringExpectedFormatMode
err error
}

func newByteStringExpectedFormatError(expectedFormatOption ByteStringExpectedFormatMode, err error) *ByteStringExpectedFormatError {
return &ByteStringExpectedFormatError{expectedFormatOption, err}
}

func (e *ByteStringExpectedFormatError) Error() string {
switch e.expectedFormatOption {
case ByteStringExpectedBase64URL:
return fmt.Sprintf("cbor: failed to decode base64url from byte string: %s", e.err)

case ByteStringExpectedBase64:
return fmt.Sprintf("cbor: failed to decode base64 from byte string: %s", e.err)

case ByteStringExpectedBase16:
return fmt.Sprintf("cbor: failed to decode hex from byte string: %s", e.err)

default:
return fmt.Sprintf("cbor: failed to decode byte string in expected format %d: %s", e.expectedFormatOption, e.err)
}
}

func (e *ByteStringExpectedFormatError) Unwrap() error {
return e.err
}

// DupMapKeyMode specifies how to enforce duplicate map key. Two map keys are considered duplicates if:
// 1. When decoding into a struct, both keys match the same struct field. The keys are also
// considered duplicates if neither matches any field and decoding to interface{} would produce
Expand Down Expand Up @@ -602,32 +634,38 @@ func (bttm ByteStringToTimeMode) valid() bool {
return bttm >= 0 && bttm < maxByteStringToTimeMode
}

// ByteSliceExpectedEncodingMode specifies how to decode a byte string NOT enclosed in an "expected
// later encoding" tag (RFC 8949 Section 3.4.5.2) into a Go byte slice.
type ByteSliceExpectedEncodingMode int
// ByteStringExpectedFormatMode specifies how to decode CBOR byte string into Go byte slice
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. The old suffix "EncodingMode" for a decoding mode name would have been confusing.

// when the byte string is NOT enclosed in CBOR tag 21, 22, or 23. An error is returned if
// the CBOR byte string does not contain the expected format (e.g. base64) specified.
// For tags 21-23, see "Expected Later Encoding for CBOR-to-JSON Converters"
// in RFC 8949 Section 3.4.5.2.
type ByteStringExpectedFormatMode int

const (
// ByteSliceExpectedEncodingIgnored copies the contents of the byte string, unmodified, into
// a destination Go byte slice.
ByteSliceExpectedEncodingIgnored = iota

// ByteSliceExpectedEncodingBase64URL assumes that byte strings with no text encoding hint
// contain base64url-encoded bytes.
ByteSliceExpectedEncodingBase64URL

// ByteSliceExpectedEncodingBase64 assumes that byte strings with no text encoding hint
// contain base64-encoded bytes.
ByteSliceExpectedEncodingBase64

// ByteSliceExpectedEncodingBase16 assumes that byte strings with no text encoding hint
// contain base16-encoded bytes.
ByteSliceExpectedEncodingBase16

maxByteSliceExpectedEncodingMode
// ByteStringExpectedFormatNone copies the unmodified CBOR byte string into Go byte slice
// if the byte string is not tagged by CBOR tag 21-23.
ByteStringExpectedFormatNone ByteStringExpectedFormatMode = iota

// ByteStringExpectedBase64URL expects CBOR byte strings to contain base64url-encoded bytes
// if the byte string is not tagged by CBOR tag 21-23. The decoder will attempt to decode
// the base64url-encoded bytes into Go slice.
ByteStringExpectedBase64URL

// ByteStringExpectedBase64 expects CBOR byte strings to contain base64-encoded bytes
// if the byte string is not tagged by CBOR tag 21-23. The decoder will attempt to decode
// the base64-encoded bytes into Go slice.
ByteStringExpectedBase64

// ByteStringExpectedBase16 expects CBOR byte strings to contain base16-encoded bytes
// if the byte string is not tagged by CBOR tag 21-23. The decoder will attempt to decode
// the base16-encoded bytes into Go slice.
ByteStringExpectedBase16

maxByteStringExpectedFormatMode
)

func (bseem ByteSliceExpectedEncodingMode) valid() bool {
return bseem >= 0 && bseem < maxByteSliceExpectedEncodingMode
func (bsefm ByteStringExpectedFormatMode) valid() bool {
return bsefm >= 0 && bsefm < maxByteStringExpectedFormatMode
}

// BignumTagMode specifies whether or not the "bignum" tags 2 and 3 (RFC 8949 Section 3.4.3) can be
Expand Down Expand Up @@ -761,7 +799,7 @@ type DecOptions struct {
// Currently, recognized CBOR tag numbers are 0, 1, 2, 3, or registered by TagSet.
UnrecognizedTagToAny UnrecognizedTagToAnyMode

// TimeTagToAnyMode specifies how to decode CBOR tag 0 and 1 into an empty interface (any).
// TimeTagToAny specifies how to decode CBOR tag 0 and 1 into an empty interface (any).
// Based on the specified mode, Unmarshal can return a time.Time value or a time string in a specific format.
TimeTagToAny TimeTagToAnyMode

Expand All @@ -783,12 +821,15 @@ type DecOptions struct {
// 25 through 27) representing positive or negative infinity.
Inf InfMode

// ByteStringToTimeMode specifies the behavior when decoding a CBOR byte string into a Go time.Time.
// ByteStringToTime specifies how to decode CBOR byte string into Go time.Time.
ByteStringToTime ByteStringToTimeMode

// ByteSliceExpectedEncodingMode specifies how to decode a byte string NOT enclosed in an
// "expected later encoding" tag (RFC 8949 Section 3.4.5.2) into a Go byte slice.
ByteSliceExpectedEncoding ByteSliceExpectedEncodingMode
// ByteStringExpectedFormat specifies how to decode CBOR byte string into Go byte slice
// when the byte string is NOT enclosed in CBOR tag 21, 22, or 23. An error is returned if
// the CBOR byte string does not contain the expected format (e.g. base64) specified.
// For tags 21-23, see "Expected Later Encoding for CBOR-to-JSON Converters"
// in RFC 8949 Section 3.4.5.2.
ByteStringExpectedFormat ByteStringExpectedFormatMode

// BignumTag specifies whether or not the "bignum" tags 2 and 3 (RFC 8949 Section 3.4.3) can
// be decoded. Unlike BigIntDec, this option applies to all bignum tags encountered in a
Expand All @@ -815,7 +856,8 @@ func (opts DecOptions) validForTags(tags TagSet) error { //nolint:gocritic // ig
if tags == nil {
return errors.New("cbor: cannot create DecMode with nil value as TagSet")
}
if opts.ByteStringToString == ByteStringToStringAllowedWithExpectedLaterEncoding || opts.ByteSliceExpectedEncoding != ByteSliceExpectedEncodingIgnored {
if opts.ByteStringToString == ByteStringToStringAllowedWithExpectedLaterEncoding ||
opts.ByteStringExpectedFormat != ByteStringExpectedFormatNone {
for _, tagNum := range []uint64{
tagNumExpectedLaterEncodingBase64URL,
tagNumExpectedLaterEncodingBase64,
Expand Down Expand Up @@ -998,8 +1040,8 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
return nil, errors.New("cbor: invalid ByteStringToTime " + strconv.Itoa(int(opts.ByteStringToTime)))
}

if !opts.ByteSliceExpectedEncoding.valid() {
return nil, errors.New("cbor: invalid ByteSliceExpectedEncoding " + strconv.Itoa(int(opts.ByteSliceExpectedEncoding)))
if !opts.ByteStringExpectedFormat.valid() {
return nil, errors.New("cbor: invalid ByteStringExpectedFormat " + strconv.Itoa(int(opts.ByteStringExpectedFormat)))
}

if !opts.BignumTag.valid() {
Expand All @@ -1011,32 +1053,32 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
}

dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
maxNestedLevels: opts.MaxNestedLevels,
maxArrayElements: opts.MaxArrayElements,
maxMapPairs: opts.MaxMapPairs,
indefLength: opts.IndefLength,
tagsMd: opts.TagsMd,
intDec: opts.IntDec,
mapKeyByteString: opts.MapKeyByteString,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
defaultByteStringType: opts.DefaultByteStringType,
byteStringToString: opts.ByteStringToString,
fieldNameByteString: opts.FieldNameByteString,
unrecognizedTagToAny: opts.UnrecognizedTagToAny,
timeTagToAny: opts.TimeTagToAny,
simpleValues: simpleValues,
nanDec: opts.NaN,
infDec: opts.Inf,
byteStringToTime: opts.ByteStringToTime,
byteSliceExpectedEncoding: opts.ByteSliceExpectedEncoding,
bignumTag: opts.BignumTag,
binaryUnmarshaler: opts.BinaryUnmarshaler,
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
maxNestedLevels: opts.MaxNestedLevels,
maxArrayElements: opts.MaxArrayElements,
maxMapPairs: opts.MaxMapPairs,
indefLength: opts.IndefLength,
tagsMd: opts.TagsMd,
intDec: opts.IntDec,
mapKeyByteString: opts.MapKeyByteString,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
defaultByteStringType: opts.DefaultByteStringType,
byteStringToString: opts.ByteStringToString,
fieldNameByteString: opts.FieldNameByteString,
unrecognizedTagToAny: opts.UnrecognizedTagToAny,
timeTagToAny: opts.TimeTagToAny,
simpleValues: simpleValues,
nanDec: opts.NaN,
infDec: opts.Inf,
byteStringToTime: opts.ByteStringToTime,
byteStringExpectedFormat: opts.ByteStringExpectedFormat,
bignumTag: opts.BignumTag,
binaryUnmarshaler: opts.BinaryUnmarshaler,
}

return &dm, nil
Expand Down Expand Up @@ -1089,33 +1131,33 @@ type DecMode interface {
}

type decMode struct {
tags tagProvider
dupMapKey DupMapKeyMode
timeTag DecTagMode
maxNestedLevels int
maxArrayElements int
maxMapPairs int
indefLength IndefLengthMode
tagsMd TagsMode
intDec IntDecMode
mapKeyByteString MapKeyByteStringMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
defaultByteStringType reflect.Type
byteStringToString ByteStringToStringMode
fieldNameByteString FieldNameByteStringMode
unrecognizedTagToAny UnrecognizedTagToAnyMode
timeTagToAny TimeTagToAnyMode
simpleValues *SimpleValueRegistry
nanDec NaNMode
infDec InfMode
byteStringToTime ByteStringToTimeMode
byteSliceExpectedEncoding ByteSliceExpectedEncodingMode
bignumTag BignumTagMode
binaryUnmarshaler BinaryUnmarshalerMode
tags tagProvider
dupMapKey DupMapKeyMode
timeTag DecTagMode
maxNestedLevels int
maxArrayElements int
maxMapPairs int
indefLength IndefLengthMode
tagsMd TagsMode
intDec IntDecMode
mapKeyByteString MapKeyByteStringMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
defaultByteStringType reflect.Type
byteStringToString ByteStringToStringMode
fieldNameByteString FieldNameByteStringMode
unrecognizedTagToAny UnrecognizedTagToAnyMode
timeTagToAny TimeTagToAnyMode
simpleValues *SimpleValueRegistry
nanDec NaNMode
infDec InfMode
byteStringToTime ByteStringToTimeMode
byteStringExpectedFormat ByteStringExpectedFormatMode
bignumTag BignumTagMode
binaryUnmarshaler BinaryUnmarshalerMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand All @@ -1130,32 +1172,32 @@ func (dm *decMode) DecOptions() DecOptions {
}

return DecOptions{
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
MaxNestedLevels: dm.maxNestedLevels,
MaxArrayElements: dm.maxArrayElements,
MaxMapPairs: dm.maxMapPairs,
IndefLength: dm.indefLength,
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
DefaultMapType: dm.defaultMapType,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
DefaultByteStringType: dm.defaultByteStringType,
ByteStringToString: dm.byteStringToString,
FieldNameByteString: dm.fieldNameByteString,
UnrecognizedTagToAny: dm.unrecognizedTagToAny,
TimeTagToAny: dm.timeTagToAny,
SimpleValues: simpleValues,
NaN: dm.nanDec,
Inf: dm.infDec,
ByteStringToTime: dm.byteStringToTime,
ByteSliceExpectedEncoding: dm.byteSliceExpectedEncoding,
BignumTag: dm.bignumTag,
BinaryUnmarshaler: dm.binaryUnmarshaler,
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
MaxNestedLevels: dm.maxNestedLevels,
MaxArrayElements: dm.maxArrayElements,
MaxMapPairs: dm.maxMapPairs,
IndefLength: dm.indefLength,
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
DefaultMapType: dm.defaultMapType,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
DefaultByteStringType: dm.defaultByteStringType,
ByteStringToString: dm.byteStringToString,
FieldNameByteString: dm.fieldNameByteString,
UnrecognizedTagToAny: dm.unrecognizedTagToAny,
TimeTagToAny: dm.timeTagToAny,
SimpleValues: simpleValues,
NaN: dm.nanDec,
Inf: dm.infDec,
ByteStringToTime: dm.byteStringToTime,
ByteStringExpectedFormat: dm.byteStringExpectedFormat,
BignumTag: dm.bignumTag,
BinaryUnmarshaler: dm.binaryUnmarshaler,
}
}

Expand Down Expand Up @@ -1531,7 +1573,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
case tagNumExpectedLaterEncodingBase64URL, tagNumExpectedLaterEncodingBase64, tagNumExpectedLaterEncodingBase16:
// If conversion for interoperability with text encodings is not configured,
// treat tags 21-23 as unregistered tags.
if d.dm.byteStringToString == ByteStringToStringAllowedWithExpectedLaterEncoding || d.dm.byteSliceExpectedEncoding != ByteSliceExpectedEncodingIgnored {
if d.dm.byteStringToString == ByteStringToStringAllowedWithExpectedLaterEncoding || d.dm.byteStringExpectedFormat != ByteStringExpectedFormatNone {
d.expectedLaterEncodingTags = append(d.expectedLaterEncodingTags, tagNum)
defer func() {
d.expectedLaterEncodingTags = d.expectedLaterEncodingTags[:len(d.expectedLaterEncodingTags)-1]
Expand Down Expand Up @@ -1923,7 +1965,8 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
case tagNumExpectedLaterEncodingBase64URL, tagNumExpectedLaterEncodingBase64, tagNumExpectedLaterEncodingBase16:
// If conversion for interoperability with text encodings is not configured,
// treat tags 21-23 as unregistered tags.
if d.dm.byteStringToString == ByteStringToStringAllowedWithExpectedLaterEncoding || d.dm.byteSliceExpectedEncoding != ByteSliceExpectedEncodingIgnored {
if d.dm.byteStringToString == ByteStringToStringAllowedWithExpectedLaterEncoding ||
d.dm.byteStringExpectedFormat != ByteStringExpectedFormatNone {
d.expectedLaterEncodingTags = append(d.expectedLaterEncodingTags, tagNum)
defer func() {
d.expectedLaterEncodingTags = d.expectedLaterEncodingTags[:len(d.expectedLaterEncodingTags)-1]
Expand Down Expand Up @@ -2082,28 +2125,28 @@ func (d *decoder) applyByteStringTextConversion(
return src, false, nil
}

switch d.dm.byteSliceExpectedEncoding {
case ByteSliceExpectedEncodingBase64URL:
switch d.dm.byteStringExpectedFormat {
case ByteStringExpectedBase64URL:
decoded := make([]byte, base64.RawURLEncoding.DecodedLen(len(src)))
n, err := base64.RawURLEncoding.Decode(decoded, src)
if err != nil {
return nil, false, fmt.Errorf("cbor: failed to decode base64url string: %v", err)
return nil, false, newByteStringExpectedFormatError(ByteStringExpectedBase64URL, err)
}
return decoded[:n], true, nil

case ByteSliceExpectedEncodingBase64:
case ByteStringExpectedBase64:
decoded := make([]byte, base64.StdEncoding.DecodedLen(len(src)))
n, err := base64.StdEncoding.Decode(decoded, src)
if err != nil {
return nil, false, fmt.Errorf("cbor: failed to decode base64 string: %v", err)
return nil, false, newByteStringExpectedFormatError(ByteStringExpectedBase64, err)
}
return decoded[:n], true, nil

case ByteSliceExpectedEncodingBase16:
case ByteStringExpectedBase16:
decoded := make([]byte, hex.DecodedLen(len(src)))
n, err := hex.Decode(decoded, src)
if err != nil {
return nil, false, fmt.Errorf("cbor: failed to decode hex string: %v", err)
return nil, false, newByteStringExpectedFormatError(ByteStringExpectedBase16, err)
}
return decoded[:n], true, nil
}
Expand Down
Loading
Loading