Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supports dynamically switch encode and decode processing for a given type #368

Merged
merged 1 commit into from
Apr 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type Decoder struct {
referenceReaders []io.Reader
anchorNodeMap map[string]ast.Node
anchorValueMap map[string]reflect.Value
customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error
toCommentMap CommentMap
opts []DecodeOption
referenceFiles []string
Expand All @@ -50,6 +51,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
reader: r,
anchorNodeMap: map[string]ast.Node{},
anchorValueMap: map[string]reflect.Value{},
customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{},
opts: opts,
referenceReaders: []io.Reader{},
referenceFiles: []string{},
Expand Down Expand Up @@ -638,8 +640,38 @@ type jsonUnmarshaler interface {
UnmarshalJSON([]byte) error
}

func (d *Decoder) existsTypeInCustomUnmarshalerMap(t reflect.Type) bool {
if _, exists := d.customUnmarshalerMap[t]; exists {
return true
}

globalCustomUnmarshalerMu.Lock()
defer globalCustomUnmarshalerMu.Unlock()
if _, exists := globalCustomUnmarshalerMap[t]; exists {
return true
}
return false
}

func (d *Decoder) unmarshalerFromCustomUnmarshalerMap(t reflect.Type) (func(interface{}, []byte) error, bool) {
if unmarshaler, exists := d.customUnmarshalerMap[t]; exists {
return unmarshaler, exists
}

globalCustomUnmarshalerMu.Lock()
defer globalCustomUnmarshalerMu.Unlock()
if unmarshaler, exists := globalCustomUnmarshalerMap[t]; exists {
return unmarshaler, exists
}
return nil, false
}

func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool {
iface := dst.Addr().Interface()
ptrValue := dst.Addr()
if d.existsTypeInCustomUnmarshalerMap(ptrValue.Type()) {
return true
}
iface := ptrValue.Interface()
switch iface.(type) {
case BytesUnmarshalerContext:
return true
Expand All @@ -662,7 +694,18 @@ func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool {
}

func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, src ast.Node) error {
iface := dst.Addr().Interface()
ptrValue := dst.Addr()
if unmarshaler, exists := d.unmarshalerFromCustomUnmarshalerMap(ptrValue.Type()); exists {
b, err := d.unmarshalableDocument(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
if err := unmarshaler(ptrValue.Interface(), b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
}
iface := ptrValue.Interface()

if unmarshaler, ok := iface.(BytesUnmarshalerContext); ok {
b, err := d.unmarshalableDocument(src)
Expand Down
48 changes: 48 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1858,6 +1858,54 @@ func TestDecoder_UseJSONUnmarshaler(t *testing.T) {
}
}

func TestDecoder_CustomUnmarshaler(t *testing.T) {
t.Run("override struct type", func(t *testing.T) {
type T struct {
Foo string `yaml:"foo"`
}
src := []byte(`foo: "bar"`)
var v T
if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomUnmarshaler[T](func(dst *T, b []byte) error {
if !bytes.Equal(src, b) {
t.Fatalf("failed to get decode target buffer. expected %q but got %q", src, b)
}
var v T
if err := yaml.Unmarshal(b, &v); err != nil {
return err
}
if v.Foo != "bar" {
t.Fatal("failed to decode")
}
dst.Foo = "bazbaz" // assign another value to target
return nil
})); err != nil {
t.Fatal(err)
}
if v.Foo != "bazbaz" {
t.Fatalf("failed to switch to custom unmarshaler. got: %v", v.Foo)
}
})
t.Run("override bytes type", func(t *testing.T) {
type T struct {
Foo []byte `yaml:"foo"`
}
src := []byte(`foo: "bar"`)
var v T
if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomUnmarshaler[[]byte](func(dst *[]byte, b []byte) error {
if !bytes.Equal(b, []byte(`"bar"`)) {
t.Fatalf("failed to get target buffer: %q", b)
}
*dst = []byte("bazbaz")
return nil
})); err != nil {
t.Fatal(err)
}
if !bytes.Equal(v.Foo, []byte("bazbaz")) {
t.Fatalf("failed to switch to custom unmarshaler. got: %q", v.Foo)
}
})
}

type unmarshalContext struct {
v int
}
Expand Down
43 changes: 43 additions & 0 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Encoder struct {
useJSONMarshaler bool
anchorCallback func(*ast.AnchorNode, interface{}) error
anchorPtrToNameMap map[uintptr]string
customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error)
useLiteralStyleIfMultiline bool
commentMap map[*Path][]*Comment
written bool
Expand All @@ -56,6 +57,7 @@ func NewEncoder(w io.Writer, opts ...EncodeOption) *Encoder {
opts: opts,
indent: DefaultIndentSpaces,
anchorPtrToNameMap: map[uintptr]string{},
customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){},
line: 1,
column: 1,
offset: 0,
Expand Down Expand Up @@ -273,10 +275,39 @@ type jsonMarshaler interface {
MarshalJSON() ([]byte, error)
}

func (e *Encoder) existsTypeInCustomMarshalerMap(t reflect.Type) bool {
if _, exists := e.customMarshalerMap[t]; exists {
return true
}

globalCustomMarshalerMu.Lock()
defer globalCustomMarshalerMu.Unlock()
if _, exists := globalCustomMarshalerMap[t]; exists {
return true
}
return false
}

func (e *Encoder) marshalerFromCustomMarshalerMap(t reflect.Type) (func(interface{}) ([]byte, error), bool) {
if marshaler, exists := e.customMarshalerMap[t]; exists {
return marshaler, exists
}

globalCustomMarshalerMu.Lock()
defer globalCustomMarshalerMu.Unlock()
if marshaler, exists := globalCustomMarshalerMap[t]; exists {
return marshaler, exists
}
return nil, false
}

func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool {
if !v.CanInterface() {
return false
}
if e.existsTypeInCustomMarshalerMap(v.Type()) {
return true
}
iface := v.Interface()
switch iface.(type) {
case BytesMarshalerContext:
Expand All @@ -302,6 +333,18 @@ func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool {
func (e *Encoder) encodeByMarshaler(ctx context.Context, v reflect.Value, column int) (ast.Node, error) {
iface := v.Interface()

if marshaler, exists := e.marshalerFromCustomMarshalerMap(v.Type()); exists {
doc, err := marshaler(iface)
if err != nil {
return nil, errors.Wrapf(err, "failed to MarshalYAML")
}
node, err := e.encodeDocument(doc)
if err != nil {
return nil, errors.Wrapf(err, "failed to encode document")
}
return node, nil
}

if marshaler, ok := iface.(BytesMarshalerContext); ok {
doc, err := marshaler.MarshalYAML(ctx)
if err != nil {
Expand Down
37 changes: 36 additions & 1 deletion encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"bytes"
"context"
"fmt"
"github.com/goccy/go-yaml/parser"
"math"
"reflect"
"strconv"
"testing"
"time"

"github.com/goccy/go-yaml/parser"

"github.com/goccy/go-yaml"
"github.com/goccy/go-yaml/ast"
)
Expand Down Expand Up @@ -1177,6 +1178,40 @@ a:
}
}

func TestEncoder_CustomMarshaler(t *testing.T) {
t.Run("override struct type", func(t *testing.T) {
type T struct {
Foo string `yaml:"foo"`
}
b, err := yaml.MarshalWithOptions(&T{Foo: "bar"}, yaml.CustomMarshaler[T](func(v T) ([]byte, error) {
return []byte(`"override"`), nil
}))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(b, []byte("\"override\"\n")) {
t.Fatalf("failed to switch to custom marshaler. got: %q", b)
}
})
t.Run("override bytes type", func(t *testing.T) {
type T struct {
Foo []byte `yaml:"foo"`
}
b, err := yaml.MarshalWithOptions(&T{Foo: []byte("bar")}, yaml.CustomMarshaler[[]byte](func(v []byte) ([]byte, error) {
if !bytes.Equal(v, []byte("bar")) {
t.Fatalf("failed to get src buffer: %q", v)
}
return []byte(`override`), nil
}))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(b, []byte("foo: override\n")) {
t.Fatalf("failed to switch to custom marshaler. got: %q", b)
}
})
}

func TestEncoder_MultipleDocuments(t *testing.T) {
var buf bytes.Buffer
enc := yaml.NewEncoder(&buf)
Expand Down
30 changes: 30 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package yaml

import (
"io"
"reflect"

"github.com/goccy/go-yaml/ast"
)
Expand Down Expand Up @@ -94,6 +95,20 @@ func UseJSONUnmarshaler() DecodeOption {
}
}

// CustomUnmarshaler overrides any decoding process for the type specified in generics.
//
// NOTE: If RegisterCustomUnmarshaler and CustomUnmarshaler of DecodeOption are specified for the same type,
// the CustomUnmarshaler specified in DecodeOption takes precedence.
func CustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) DecodeOption {
return func(d *Decoder) error {
var typ *T
d.customUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, b []byte) error {
return unmarshaler(v.(*T), b)
}
return nil
}
}

// EncodeOption functional option type for Encoder
type EncodeOption func(e *Encoder) error

Expand Down Expand Up @@ -165,6 +180,21 @@ func UseJSONMarshaler() EncodeOption {
}
}

// CustomMarshaler overrides any encoding process for the type specified in generics.
//
// NOTE: If type T implements MarshalYAML for pointer receiver, the type specified in CustomMarshaler must be *T.
// If RegisterCustomMarshaler and CustomMarshaler of EncodeOption are specified for the same type,
// the CustomMarshaler specified in EncodeOption takes precedence.
func CustomMarshaler[T any](marshaler func(T) ([]byte, error)) EncodeOption {
return func(e *Encoder) error {
var typ T
e.customMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) ([]byte, error) {
return marshaler(v.(T))
}
return nil
}
}

// CommentPosition type of the position for comment.
type CommentPosition int

Expand Down
40 changes: 40 additions & 0 deletions yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"context"
"io"
"reflect"
"sync"

"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/internal/errors"
Expand Down Expand Up @@ -248,3 +250,41 @@ func JSONToYAML(bytes []byte) ([]byte, error) {
}
return out, nil
}

var (
globalCustomMarshalerMu sync.Mutex
globalCustomUnmarshalerMu sync.Mutex
globalCustomMarshalerMap = map[reflect.Type]func(interface{}) ([]byte, error){}
globalCustomUnmarshalerMap = map[reflect.Type]func(interface{}, []byte) error{}
)

// RegisterCustomMarshaler overrides any encoding process for the type specified in generics.
// If you want to switch the behavior for each encoder, use `CustomMarshaler` defined as EncodeOption.
//
// NOTE: If type T implements MarshalYAML for pointer receiver, the type specified in RegisterCustomMarshaler must be *T.
// If RegisterCustomMarshaler and CustomMarshaler of EncodeOption are specified for the same type,
// the CustomMarshaler specified in EncodeOption takes precedence.
func RegisterCustomMarshaler[T any](marshaler func(T) ([]byte, error)) {
globalCustomMarshalerMu.Lock()
defer globalCustomMarshalerMu.Unlock()

var typ T
globalCustomMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) ([]byte, error) {
return marshaler(v.(T))
}
}

// RegisterCustomUnmarshaler overrides any decoding process for the type specified in generics.
// If you want to switch the behavior for each decoder, use `CustomUnmarshaler` defined as DecodeOption.
//
// NOTE: If RegisterCustomUnmarshaler and CustomUnmarshaler of DecodeOption are specified for the same type,
// the CustomUnmarshaler specified in DecodeOption takes precedence.
func RegisterCustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) {
globalCustomUnmarshalerMu.Lock()
defer globalCustomUnmarshalerMu.Unlock()

var typ *T
globalCustomUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, b []byte) error {
return unmarshaler(v.(*T), b)
}
}
Loading