Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to specify buffer by adding cbor.MarshalToBuffer(), UserBufferEncMode interface, etc. #553

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"math/big"
Expand Down Expand Up @@ -95,6 +96,17 @@ func Marshal(v interface{}) ([]byte, error) {
return defaultEncMode.Marshal(v)
}

// MarshalToBuffer encodes v into provided buffer (instead of using built-in buffer pool)
// and uses default encoding options.
//
// NOTE: Unlike Marshal, the buffer provided to MarshalToBuffer can contain
// partially encoded data if error is returned.
//
// See Marshal for more details.
func MarshalToBuffer(v interface{}, buf *bytes.Buffer) error {
return defaultEncMode.MarshalToBuffer(v, buf)
}

// Marshaler is the interface implemented by types that can marshal themselves
// into valid CBOR.
type Marshaler interface {
Expand Down Expand Up @@ -617,8 +629,18 @@ func (opts EncOptions) EncMode() (EncMode, error) { //nolint:gocritic // ignore
return opts.encMode()
}

// UserBufferEncMode returns UserBufferEncMode with immutable options and no tags (safe for concurrency).
func (opts EncOptions) UserBufferEncMode() (UserBufferEncMode, error) { //nolint:gocritic // ignore hugeParam
return opts.encMode()
}

// EncModeWithTags returns EncMode with options and tags that are both immutable (safe for concurrency).
func (opts EncOptions) EncModeWithTags(tags TagSet) (EncMode, error) { //nolint:gocritic // ignore hugeParam
return opts.UserBufferEncModeWithTags(tags)
}

// UserBufferEncModeWithTags returns UserBufferEncMode with options and tags that are both immutable (safe for concurrency).
func (opts EncOptions) UserBufferEncModeWithTags(tags TagSet) (UserBufferEncMode, error) { //nolint:gocritic // ignore hugeParam
if opts.TagsMd == TagsForbidden {
return nil, errors.New("cbor: cannot create EncMode with TagSet when TagsMd is TagsForbidden")
}
Expand Down Expand Up @@ -647,6 +669,11 @@ func (opts EncOptions) EncModeWithTags(tags TagSet) (EncMode, error) { //nolint:

// EncModeWithSharedTags returns EncMode with immutable options and mutable shared tags (safe for concurrency).
func (opts EncOptions) EncModeWithSharedTags(tags TagSet) (EncMode, error) { //nolint:gocritic // ignore hugeParam
return opts.UserBufferEncModeWithSharedTags(tags)
}

// UserBufferEncModeWithSharedTags returns UserBufferEncMode with immutable options and mutable shared tags (safe for concurrency).
func (opts EncOptions) UserBufferEncModeWithSharedTags(tags TagSet) (UserBufferEncMode, error) { //nolint:gocritic // ignore hugeParam
if opts.TagsMd == TagsForbidden {
return nil, errors.New("cbor: cannot create EncMode with TagSet when TagsMd is TagsForbidden")
}
Expand Down Expand Up @@ -745,6 +772,20 @@ type EncMode interface {
EncOptions() EncOptions
}

// UserBufferEncMode is an interface for CBOR encoding, which extends EncMode by
// adding MarshalToBuffer to support user specified buffer rather than encoding
// into the built-in buffer pool.
type UserBufferEncMode interface {
EncMode
MarshalToBuffer(v interface{}, buf *bytes.Buffer) error

// This private method is to prevent users implementing
// this interface and so future additions to it will
// not be breaking changes.
// See https://go.dev/blog/module-compatibility
unexport()
}

type encMode struct {
tags tagProvider
sort SortMode
Expand Down Expand Up @@ -860,6 +901,8 @@ func (em *encMode) EncOptions() EncOptions {
}
}

func (em *encMode) unexport() {}

func (em *encMode) encTagBytes(t reflect.Type) []byte {
if em.tags != nil {
if tagItem := em.tags.getTagItemFromType(t); tagItem != nil {
Expand Down Expand Up @@ -887,7 +930,17 @@ func (em *encMode) Marshal(v interface{}) ([]byte, error) {
return buf, nil
}

// MarshalToBuffer encodes v into provided buffer (instead of using built-in buffer pool)
// and uses em encoding mode.
//
// NOTE: Unlike Marshal, the buffer provided to MarshalToBuffer can contain
// partially encoded data if error is returned.
//
// See Marshal for more details.
func (em *encMode) MarshalToBuffer(v interface{}, buf *bytes.Buffer) error {
if buf == nil {
return fmt.Errorf("cbor: encoding buffer provided by user is nil")
}
return encode(buf, em, reflect.ValueOf(v))
}

Expand Down
21 changes: 21 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -498,16 +498,37 @@ func testMarshal(t *testing.T, testCases []marshalTest) {
if err != nil {
t.Errorf("EncMode() returned an error %v", err)
}
bem, err := EncOptions{Sort: SortCanonical}.UserBufferEncMode()
if err != nil {
t.Errorf("UserBufferEncMode() returned an error %v", err)
}
for _, tc := range testCases {
for _, value := range tc.values {
// Encode value using default options
if _, err := Marshal(value); err != nil {
t.Errorf("Marshal(%v) returned error %v", value, err)
}

// Encode value to provided buffer using default options
var buf1 bytes.Buffer
if err := MarshalToBuffer(value, &buf1); err != nil {
t.Errorf("MarshalToBuffer(%v) returned error %v", value, err)
}

// Encode value using specified options
if b, err := em.Marshal(value); err != nil {
t.Errorf("Marshal(%v) returned error %v", value, err)
} else if !bytes.Equal(b, tc.wantData) {
t.Errorf("Marshal(%v) = 0x%x, want 0x%x", value, b, tc.wantData)
}

// Encode value to provided buffer using specified options
var buf2 bytes.Buffer
if err := bem.MarshalToBuffer(value, &buf2); err != nil {
t.Errorf("MarshalToBuffer(%v) returned error %v", value, err)
} else if !bytes.Equal(buf2.Bytes(), tc.wantData) {
t.Errorf("Marshal(%v) = 0x%x, want 0x%x", value, buf2.Bytes(), tc.wantData)
}
}
r := RawMessage(tc.wantData)
if b, err := Marshal(r); err != nil {
Expand Down
Loading