From fdf5bd837811dc55be58ae31843f0d4d0d40511a Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Mon, 27 May 2024 21:14:04 -0500 Subject: [PATCH] Refactor to remove more magic numbers --- common.go | 14 ++++++++++++-- decode.go | 36 +++++++++++++++++++++++++++++++++++ diagnose.go | 18 ++++++++++-------- encode.go | 37 +++++++++++++++++++++++++++++------- encode_test.go | 51 ++++++++++++++++++++++++++++++++++---------------- simplevalue.go | 8 ++++---- stream.go | 8 ++++---- 7 files changed, 131 insertions(+), 41 deletions(-) diff --git a/common.go b/common.go index 4e11f0f5..2793ba3f 100644 --- a/common.go +++ b/common.go @@ -53,7 +53,7 @@ const ( additionalInformationWith4ByteArgument = 26 additionalInformationWith8ByteArgument = 27 - // additional information with major type 7 + // For major type 7. additionalInformationAsFalse = 20 additionalInformationAsTrue = 21 additionalInformationAsNull = 22 @@ -62,9 +62,15 @@ const ( additionalInformationAsFloat32 = 26 additionalInformationAsFloat64 = 27 + // For major type 2, 3, 4, 5. additionalInformationAsIndefiniteLengthFlag = 31 ) +const ( + maxSimpleValueInAdditionalInformation = 23 + minSimpleValueIn1ByteArgument = 32 +) + func (ai additionalInformation) isIndefiniteLength() bool { return ai == additionalInformationAsIndefiniteLengthFlag } @@ -110,7 +116,11 @@ const ( ) const ( - cborBreakFlag = byte(0xff) + cborBreakFlag = byte(0xff) + cborByteStringWithIndefiniteLengthHead = byte(0x5f) + cborTextStringWithIndefiniteLengthHead = byte(0x7f) + cborArrayWithIndefiniteLengthHead = byte(0x9f) + cborMapWithIndefiniteLengthHead = byte(0xbf) ) var ( diff --git a/decode.go b/decode.go index ea943977..5449d533 100644 --- a/decode.go +++ b/decode.go @@ -1358,8 +1358,10 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin v.Set(reflect.ValueOf(iv)) } return err + case specialTypeTag: return d.parseToTag(v) + case specialTypeTime: if d.nextCBORNil() { // Decoding CBOR null and undefined to time.Time is no-op. @@ -1374,6 +1376,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin v.Set(reflect.ValueOf(tm)) } return nil + case specialTypeUnmarshalerIface: return d.parseToUnmarshaler(v) } @@ -1535,6 +1538,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin }() } } + return d.parseToValue(v, tInfo) case cborTypeArray: @@ -1628,6 +1632,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) { return t, true, nil } return time.Time{}, false, &UnmarshalTypeError{CBORType: t.String(), GoType: typeTime.String()} + case cborTypeTextString: s, err := d.parseTextString() if err != nil { @@ -1638,6 +1643,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) { return time.Time{}, false, errors.New("cbor: cannot set " + string(s) + " for time.Time: " + err.Error()) } return t, true, nil + case cborTypePositiveInt: _, _, val := d.getHead() if val > math.MaxInt64 { @@ -1648,6 +1654,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) { } } return time.Unix(int64(val), 0), true, nil + case cborTypeNegativeInt: _, _, val := d.getHead() if val > math.MaxInt64 { @@ -1667,6 +1674,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) { } } return time.Unix(int64(-1)^int64(val), 0), true, nil + case cborTypePrimitives: _, ai, val := d.getHead() var f float64 @@ -1690,6 +1698,7 @@ func (d *decoder) parseToTime() (time.Time, bool, error) { } seconds, fractional := math.Modf(f) return time.Unix(int64(seconds), int64(fractional*1e9)), true, nil + default: return time.Time{}, false, &UnmarshalTypeError{CBORType: t.String(), GoType: typeTime.String()} } @@ -1822,8 +1831,10 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli clone := make([]byte, len(b)) copy(clone, b) return clone, nil + case typeString: return string(b), nil + default: if copied || d.dm.defaultByteStringType.Kind() == reflect.String { // Avoid an unnecessary copy since the conversion to string must @@ -1834,12 +1845,14 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli copy(clone, b) return reflect.ValueOf(clone).Convert(d.dm.defaultByteStringType).Interface(), nil } + case cborTypeTextString: b, err := d.parseTextString() if err != nil { return nil, err } return string(b), nil + case cborTypeTag: tagOff := d.off _, _, tagNum := d.getHead() @@ -1852,9 +1865,11 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli if err != nil { return nil, err } + switch d.dm.timeTagToAny { case TimeTagToTime: return tm, nil + case TimeTagToRFC3339: if tagNum == 1 { tm = tm.UTC() @@ -1866,6 +1881,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli return nil, err } return string(text), nil + case TimeTagToRFC3339Nano: if tagNum == 1 { tm = tm.UTC() @@ -1877,6 +1893,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli return nil, err } return string(text), nil + default: // not reachable } @@ -1953,6 +1970,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli if ai < 20 || ai == 24 { return SimpleValue(val), nil } + switch ai { case additionalInformationAsFalse, additionalInformationAsTrue: @@ -1977,6 +1995,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli case cborTypeArray: return d.parseArray() + case cborTypeMap: if d.dm.defaultMapType != nil { m := reflect.New(d.dm.defaultMapType) @@ -1988,6 +2007,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli } return d.parseMap() } + return nil, nil } @@ -2035,19 +2055,23 @@ func (d *decoder) applyByteStringTextConversion( encoded := make([]byte, base64.RawURLEncoding.EncodedLen(len(src))) base64.RawURLEncoding.Encode(encoded, src) return encoded, true, nil + case tagNumExpectedLaterEncodingBase64: encoded := make([]byte, base64.StdEncoding.EncodedLen(len(src))) base64.StdEncoding.Encode(encoded, src) return encoded, true, nil + case tagNumExpectedLaterEncodingBase16: encoded := make([]byte, hex.EncodedLen(len(src))) hex.Encode(encoded, src) return encoded, true, nil + default: // If this happens, there is a bug: the decoder has pushed an invalid // "expected later encoding" tag to the stack. panic(fmt.Sprintf("unrecognized expected later encoding tag: %d", d.expectedLaterEncodingTags)) } + case reflect.Slice: if dstType.Elem().Kind() != reflect.Uint8 || len(d.expectedLaterEncodingTags) > 0 { // Either the destination is not a slice of bytes, or the encoder that @@ -2064,6 +2088,7 @@ func (d *decoder) applyByteStringTextConversion( return nil, false, fmt.Errorf("cbor: failed to decode base64url string: %v", err) } return decoded[:n], true, nil + case ByteSliceExpectedEncodingBase64: decoded := make([]byte, base64.StdEncoding.DecodedLen(len(src))) n, err := base64.StdEncoding.Decode(decoded, src) @@ -2071,6 +2096,7 @@ func (d *decoder) applyByteStringTextConversion( return nil, false, fmt.Errorf("cbor: failed to decode base64 string: %v", err) } return decoded[:n], true, nil + case ByteSliceExpectedEncodingBase16: decoded := make([]byte, hex.DecodedLen(len(src))) n, err := hex.Decode(decoded, src) @@ -2756,14 +2782,17 @@ func (d *decoder) skip() { switch t { case cborTypeByteString, cborTypeTextString: d.off += int(val) + case cborTypeArray: for i := 0; i < int(val); i++ { d.skip() } + case cborTypeMap: for i := 0; i < int(val)*2; i++ { d.skip() } + case cborTypeTag: d.skip() } @@ -2893,6 +2922,7 @@ func fillPositiveInt(t cborType, val uint64, v reflect.Value) error { } v.SetInt(int64(val)) return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if v.OverflowUint(val) { return &UnmarshalTypeError{ @@ -2903,11 +2933,13 @@ func fillPositiveInt(t cborType, val uint64, v reflect.Value) error { } v.SetUint(val) return nil + case reflect.Float32, reflect.Float64: f := float64(val) v.SetFloat(f) return nil } + if v.Type() == typeBigInt { i := new(big.Int).SetUint64(val) v.Set(reflect.ValueOf(*i)) @@ -2928,6 +2960,7 @@ func fillNegativeInt(t cborType, val int64, v reflect.Value) error { } v.SetInt(val) return nil + case reflect.Float32, reflect.Float64: f := float64(val) v.SetFloat(f) @@ -3026,6 +3059,7 @@ func isImmutableKind(k reflect.Kind) bool { reflect.Float32, reflect.Float64, reflect.String: return true + default: return false } @@ -3035,6 +3069,7 @@ func isHashableValue(rv reflect.Value) bool { switch rv.Kind() { case reflect.Slice, reflect.Map, reflect.Func: return false + case reflect.Struct: switch rv.Type() { case typeTag: @@ -3057,6 +3092,7 @@ func convertByteSliceToByteString(v interface{}) (interface{}, bool) { switch v := v.(type) { case []byte: return ByteString(v), true + case Tag: content, converted := convertByteSliceToByteString(v.Content) if converted { diff --git a/diagnose.go b/diagnose.go index 05850f1a..2794df3f 100644 --- a/diagnose.go +++ b/diagnose.go @@ -242,16 +242,17 @@ func (di *diagnose) wellformed(allowExtraData bool) error { func (di *diagnose) item() error { //nolint:gocyclo initialByte := di.d.data[di.d.off] switch initialByte { - case 0x5f, 0x7f: // indefinite-length byte/text string + case cborByteStringWithIndefiniteLengthHead, + cborTextStringWithIndefiniteLengthHead: // indefinite-length byte/text string di.d.off++ if isBreakFlag(di.d.data[di.d.off]) { di.d.off++ switch initialByte { - case 0x5f: + case cborByteStringWithIndefiniteLengthHead: // indefinite-length bytes with no chunks. di.w.WriteString(`''_`) return nil - case 0x7f: + case cborTextStringWithIndefiniteLengthHead: // indefinite-length text with no chunks. di.w.WriteString(`""_`) return nil @@ -276,7 +277,7 @@ func (di *diagnose) item() error { //nolint:gocyclo di.w.WriteByte(')') return nil - case 0x9f: // indefinite-length array + case cborArrayWithIndefiniteLengthHead: // indefinite-length array di.d.off++ di.w.WriteString("[_ ") @@ -295,7 +296,7 @@ func (di *diagnose) item() error { //nolint:gocyclo di.w.WriteByte(']') return nil - case 0xbf: // indefinite-length map + case cborMapWithIndefiniteLengthHead: // indefinite-length map di.d.off++ di.w.WriteString("{_ ") @@ -573,7 +574,7 @@ func (di *diagnose) encodeByteString(val []byte) error { } } -var utf16SurrSelf = rune(0x10000) +const utf16SurrSelf = rune(0x10000) // quote should be either `'` or `"` func (di *diagnose) encodeTextString(val string, quote byte) error { @@ -678,16 +679,17 @@ func (di *diagnose) encodeFloat(ai byte, val uint64) error { } // Use ES6 number to string conversion which should match most JSON generators. // Inspired by https://github.com/golang/go/blob/4df10fba1687a6d4f51d7238a403f8f2298f6a16/src/encoding/json/encode.go#L585 + const bitSize = 64 b := make([]byte, 0, 32) if abs := math.Abs(f64); abs != 0 && (abs < 1e-6 || abs >= 1e21) { - b = strconv.AppendFloat(b, f64, 'e', -1, 64) + b = strconv.AppendFloat(b, f64, 'e', -1, bitSize) // clean up e-09 to e-9 n := len(b) if n >= 4 && string(b[n-4:n-1]) == "e-0" { b = append(b[:n-2], b[n-1]) } } else { - b = strconv.AppendFloat(b, f64, 'f', -1, 64) + b = strconv.AppendFloat(b, f64, 'f', -1, bitSize) } // add decimal point and trailing zero if needed diff --git a/encode.go b/encode.go index 9c79d379..25dbf6e9 100644 --- a/encode.go +++ b/encode.go @@ -195,6 +195,7 @@ func (st StringMode) cborType() (cborType, error) { switch st { case StringToTextString: return cborTypeTextString, nil + case StringToByteString: return cborTypeByteString, nil } @@ -417,10 +418,13 @@ func (bsm ByteSliceMode) encodingTag() (uint64, error) { switch bsm { case ByteSliceToByteString: return 0, nil + case ByteSliceToByteStringWithExpectedConversionToBase64URL: return tagNumExpectedLaterEncodingBase64URL, nil + case ByteSliceToByteStringWithExpectedConversionToBase64: return tagNumExpectedLaterEncodingBase64, nil + case ByteSliceToByteStringWithExpectedConversionToBase16: return tagNumExpectedLaterEncodingBase16, nil } @@ -978,9 +982,9 @@ func encodeFloat(e *bytes.Buffer, em *encMode, v reflect.Value) error { // Encode float64 // Don't use encodeFloat64() because it cannot be inlined. var scratch [9]byte - scratch[0] = byte(cborTypePrimitives) | byte(27) + scratch[0] = byte(cborTypePrimitives) | byte(additionalInformationAsFloat64) binary.BigEndian.PutUint64(scratch[1:], math.Float64bits(f64)) - e.Write(scratch[:9]) + e.Write(scratch[:]) return nil } @@ -1002,7 +1006,7 @@ func encodeFloat(e *bytes.Buffer, em *encMode, v reflect.Value) error { // Encode float16 // Don't use encodeFloat16() because it cannot be inlined. var scratch [3]byte - scratch[0] = byte(cborTypePrimitives) | byte(25) + scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat16 binary.BigEndian.PutUint16(scratch[1:], uint16(f16)) e.Write(scratch[:3]) return nil @@ -1012,7 +1016,7 @@ func encodeFloat(e *bytes.Buffer, em *encMode, v reflect.Value) error { // Encode float32 // Don't use encodeFloat32() because it cannot be inlined. var scratch [5]byte - scratch[0] = byte(cborTypePrimitives) | byte(26) + scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat32 binary.BigEndian.PutUint32(scratch[1:], math.Float32bits(f32)) e.Write(scratch[:5]) return nil @@ -1023,6 +1027,7 @@ func encodeInf(e *bytes.Buffer, em *encMode, v reflect.Value) error { switch em.infConvert { case InfConvertReject: return &UnsupportedValueError{msg: "floating-point infinity"} + case InfConvertFloat16: if f64 > 0 { e.Write(cborPositiveInfinity) @@ -1100,7 +1105,7 @@ func encodeNaN(e *bytes.Buffer, em *encMode, v reflect.Value) error { func encodeFloat16(e *bytes.Buffer, f16 float16.Float16) error { var scratch [3]byte - scratch[0] = byte(cborTypePrimitives) | byte(25) + scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat16 binary.BigEndian.PutUint16(scratch[1:], uint16(f16)) e.Write(scratch[:3]) return nil @@ -1108,7 +1113,7 @@ func encodeFloat16(e *bytes.Buffer, f16 float16.Float16) error { func encodeFloat32(e *bytes.Buffer, f32 float32) error { var scratch [5]byte - scratch[0] = byte(cborTypePrimitives) | byte(26) + scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat32 binary.BigEndian.PutUint32(scratch[1:], math.Float32bits(f32)) e.Write(scratch[:5]) return nil @@ -1116,7 +1121,7 @@ func encodeFloat32(e *bytes.Buffer, f32 float32) error { func encodeFloat64(e *bytes.Buffer, f64 float64) error { var scratch [9]byte - scratch[0] = byte(cborTypePrimitives) | byte(27) + scratch[0] = byte(cborTypePrimitives) | additionalInformationAsFloat64 binary.BigEndian.PutUint64(scratch[1:], math.Float64bits(f64)) e.Write(scratch[:9]) return nil @@ -1478,10 +1483,12 @@ func encodeTime(e *bytes.Buffer, em *encMode, v reflect.Value) error { case TimeUnix: secs := t.Unix() return encodeInt(e, em, reflect.ValueOf(secs)) + case TimeUnixMicro: t = t.UTC().Round(time.Microsecond) f := float64(t.UnixNano()) / 1e9 return encodeFloat(e, em, reflect.ValueOf(f)) + case TimeUnixDynamic: t = t.UTC().Round(time.Microsecond) secs, nsecs := t.Unix(), uint64(t.Nanosecond()) @@ -1490,9 +1497,11 @@ func encodeTime(e *bytes.Buffer, em *encMode, v reflect.Value) error { } f := float64(secs) + float64(nsecs)/1e9 return encodeFloat(e, em, reflect.ValueOf(f)) + case TimeRFC3339: s := t.Format(time.RFC3339) return encodeString(e, em, reflect.ValueOf(s)) + default: // TimeRFC3339Nano s := t.Format(time.RFC3339Nano) return encodeString(e, em, reflect.ValueOf(s)) @@ -1690,14 +1699,19 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc) { switch t { case typeSimpleValue: return encodeMarshalerType, isEmptyUint + case typeTag: return encodeTag, alwaysNotEmpty + case typeTime: return encodeTime, alwaysNotEmpty + case typeBigInt: return encodeBigInt, alwaysNotEmpty + case typeRawMessage: return encodeMarshalerType, isEmptySlice + case typeByteString: return encodeMarshalerType, isEmptyString } @@ -1718,31 +1732,39 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc) { switch k { case reflect.Bool: return encodeBool, isEmptyBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return encodeInt, isEmptyInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return encodeUint, isEmptyUint + case reflect.Float32, reflect.Float64: return encodeFloat, isEmptyFloat + case reflect.String: return encodeString, isEmptyString + case reflect.Slice: if t.Elem().Kind() == reflect.Uint8 { return encodeByteString, isEmptySlice } fallthrough + case reflect.Array: f, _ := getEncodeFunc(t.Elem()) if f == nil { return nil, nil } return arrayEncodeFunc{f: f}.encode, isEmptySlice + case reflect.Map: f := getEncodeMapFunc(t) if f == nil { return nil, nil } return f, isEmptyMap + case reflect.Struct: // Get struct's special field "_" tag options if f, ok := t.FieldByName("_"); ok { @@ -1754,6 +1776,7 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc) { } } return encodeStruct, isEmptyStruct + case reflect.Interface: return encodeIntf, isEmptyIntf } diff --git a/encode_test.go b/encode_test.go index 463bef88..4b7cc826 100644 --- a/encode_test.go +++ b/encode_test.go @@ -449,27 +449,46 @@ func TestMarshalLargeMap(t *testing.T) { } func encodeCborHeader(t cborType, n uint64) []byte { - b := make([]byte, 9) - if n <= 23 { + if n <= maxAdditionalInformationWithoutArgument { + const headSize = 1 + var b [headSize]byte b[0] = byte(t) | byte(n) - return b[:1] - } else if n <= math.MaxUint8 { - b[0] = byte(t) | byte(24) + return b[:] + } + + if n <= math.MaxUint8 { + const argumentSize = 1 + const headSize = 1 + argumentSize + var b [headSize]byte + b[0] = byte(t) | additionalInformationWith1ByteArgument b[1] = byte(n) - return b[:2] - } else if n <= math.MaxUint16 { - b[0] = byte(t) | byte(25) + return b[:] + } + + if n <= math.MaxUint16 { + const argumentSize = 2 + const headSize = 1 + argumentSize + var b [headSize]byte + b[0] = byte(t) | additionalInformationWith2ByteArgument binary.BigEndian.PutUint16(b[1:], uint16(n)) - return b[:3] - } else if n <= math.MaxUint32 { - b[0] = byte(t) | byte(26) + return b[:] + } + + if n <= math.MaxUint32 { + const argumentSize = 4 + const headSize = 1 + argumentSize + var b [headSize]byte + b[0] = byte(t) | additionalInformationWith4ByteArgument binary.BigEndian.PutUint32(b[1:], uint32(n)) - return b[:5] - } else { - b[0] = byte(t) | byte(27) - binary.BigEndian.PutUint64(b[1:], n) - return b[:9] + return b[:] } + + const argumentSize = 8 + const headSize = 1 + argumentSize + var b [headSize]byte + b[0] = byte(t) | additionalInformationWith8ByteArgument + binary.BigEndian.PutUint64(b[1:], n) + return b[:] } func testMarshal(t *testing.T, testCases []marshalTest) { diff --git a/simplevalue.go b/simplevalue.go index 6f93f67c..de175cee 100644 --- a/simplevalue.go +++ b/simplevalue.go @@ -33,11 +33,11 @@ func (sv SimpleValue) MarshalCBOR() ([]byte, error) { // only has a single representation variant)." switch { - case sv <= 23: + case sv <= maxSimpleValueInAdditionalInformation: return []byte{byte(cborTypePrimitives) | byte(sv)}, nil - case sv >= 32: - return []byte{byte(cborTypePrimitives) | byte(24), byte(sv)}, nil + case sv >= minSimpleValueIn1ByteArgument: + return []byte{byte(cborTypePrimitives) | additionalInformationWith1ByteArgument, byte(sv)}, nil default: return nil, &UnsupportedValueError{msg: fmt.Sprintf("SimpleValue(%d)", sv)} @@ -57,7 +57,7 @@ func (sv *SimpleValue) UnmarshalCBOR(data []byte) error { if typ != cborTypePrimitives { return &UnmarshalTypeError{CBORType: typ.String(), GoType: "SimpleValue"} } - if ai > 24 { + if ai > additionalInformationWith1ByteArgument { return &UnmarshalTypeError{CBORType: typ.String(), GoType: "SimpleValue", errorMsg: "not simple values"} } diff --git a/stream.go b/stream.go index dcb60b44..507ab6c1 100644 --- a/stream.go +++ b/stream.go @@ -239,10 +239,10 @@ func (enc *Encoder) EndIndefinite() error { } var cborIndefHeader = map[cborType][]byte{ - cborTypeByteString: {0x5f}, - cborTypeTextString: {0x7f}, - cborTypeArray: {0x9f}, - cborTypeMap: {0xbf}, + cborTypeByteString: {cborByteStringWithIndefiniteLengthHead}, + cborTypeTextString: {cborTextStringWithIndefiniteLengthHead}, + cborTypeArray: {cborArrayWithIndefiniteLengthHead}, + cborTypeMap: {cborMapWithIndefiniteLengthHead}, } func (enc *Encoder) startIndefinite(typ cborType) error {