Skip to content

Commit

Permalink
Support custom struct tags (#347)
Browse files Browse the repository at this point in the history
* Implement custom tag support via msgp:tag directive
* Remove duplicate tag directive
* Add custom tag enc/dec test
  • Loading branch information
very-amused committed Jun 3, 2024
1 parent 2236701 commit abadd67
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 3 deletions.
9 changes: 9 additions & 0 deletions _generated/custom_tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package _generated

//go:generate msgp
//msgp:tag mytag

type CustomTag struct {
Foo string `mytag:"foo_custom_name"`
Bar int `mytag:"bar1234"`
}
64 changes: 64 additions & 0 deletions _generated/custom_tag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package _generated

import (
"encoding/json"
"fmt"
"reflect"
"testing"

"bytes"

"github.com/tinylib/msgp/msgp"
)

func TestCustomTag(t *testing.T) {
t.Run("File Scope", func(t *testing.T) {
ts := CustomTag{
Foo: "foostring13579",
Bar: 999_999}
encDecCustomTag(t, ts, "mytag")
})
}

func encDecCustomTag(t *testing.T, testStruct msgp.Encodable, tag string) {
var b bytes.Buffer
msgp.Encode(&b, testStruct)

// Check tag names using JSON as an intermediary layer
// TODO: is there a way to avoid the JSON layer? We'd need to directly decode raw msgpack -> map[string]any
refJSON, err := json.Marshal(testStruct)
if err != nil {
t.Error(fmt.Sprintf("error encoding struct as JSON: %v", err))
}
ref := make(map[string]any)
// Encoding and decoding the original struct via JSON is necessary
// for field comparisons to work, since JSON -> map[string]any
// relies on type inferences such as all numbers being float64s
json.Unmarshal(refJSON, &ref)

var encJSON bytes.Buffer
msgp.UnmarshalAsJSON(&encJSON, b.Bytes())
encoded := make(map[string]any)
json.Unmarshal(encJSON.Bytes(), &encoded)

tsType := reflect.TypeOf(testStruct)
for i := 0; i < tsType.NumField(); i++ {
// Check encoded field name
field := tsType.Field(i)
encodedValue, ok := encoded[field.Tag.Get(tag)]
if !ok {
t.Error("missing encoded value for field", field.Name)
continue
}
// Check encoded field value (against original value post-JSON enc + dec)
jsonName, ok := field.Tag.Lookup("json")
if !ok {
jsonName = field.Name
}
refValue := ref[jsonName]
if !reflect.DeepEqual(refValue, encodedValue) {
t.Error(fmt.Sprintf("incorrect encoded value for field %s. reference: %v, encoded: %v",
field.Name, refValue, encodedValue))
}
}
}
20 changes: 18 additions & 2 deletions parse/directives.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,15 @@ type passDirective func(gen.Method, []string, *gen.Printer) error
var directives = map[string]directive{
"shim": applyShim,
"ignore": ignore,
"tuple": astuple,
}
"tuple": astuple}

// map of all recognized directives which will be applied
// before process() is called
//
// to add an early directive, define a func([]string, *FileSet) error
// and then add it to this list.
var earlyDirectives = map[string]directive{
"tag": tag}

var passDirectives = map[string]passDirective{
"ignore": passignore,
Expand Down Expand Up @@ -128,3 +135,12 @@ func astuple(text []string, f *FileSet) error {
}
return nil
}

//msgp:tag {tagname}
func tag(text []string, f *FileSet) error {
if len(text) != 2 {
return nil
}
f.tagName = strings.TrimSpace(text[1])
return nil
}
33 changes: 32 additions & 1 deletion parse/getast.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type FileSet struct {
Identities map[string]gen.Elem // processed from specs
Directives []string // raw preprocessor directives
Imports []*ast.ImportSpec // imports
tagName string // tag to read field names from
}

// File parses a file at the relative path
Expand Down Expand Up @@ -82,6 +83,7 @@ func File(name string, unexported bool) (*FileSet, error) {
return nil, fmt.Errorf("no definitions in %s", name)
}

fs.applyEarlyDirectives()
fs.process()
fs.applyDirectives()
fs.propInline()
Expand Down Expand Up @@ -112,6 +114,29 @@ func (f *FileSet) applyDirectives() {
f.Directives = newdirs
}

// applyEarlyDirectives applies all early directives needed before process() is called.
// additional directives remain in f.Directives for future processing
func (f *FileSet) applyEarlyDirectives() {
newdirs := make([]string, 0, len(f.Directives))
for _, d := range f.Directives {
parts := strings.Split(d, " ")
if len(parts) == 0 {
continue
}
if fn, ok := earlyDirectives[parts[0]]; ok {
pushstate(parts[0])
err := fn(parts, f)
if err != nil {
warnf("early directive error: %s", err)
}
popstate()
} else {
newdirs = append(newdirs, d)
}
}
f.Directives = newdirs
}

// A linkset is a graph of unresolved
// identities.
//
Expand Down Expand Up @@ -329,7 +354,13 @@ func (fs *FileSet) getField(f *ast.Field) []gen.StructField {
var extension, flatten bool
// parse tag; otherwise field name is field tag
if f.Tag != nil {
body := reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get("msg")
var body string
if fs.tagName != "" {
body = reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get(fs.tagName)
}
if body == "" {
body = reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get("msg")
}
if body == "" {
body = reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get("msgpack")
}
Expand Down

0 comments on commit abadd67

Please sign in to comment.