Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allownil: Allocate 0 length slices #336

Merged
merged 4 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions _generated/allownil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package _generated

import (
"bytes"
"reflect"
"testing"

"github.com/tinylib/msgp/msgp"
)

func TestAllownil(t *testing.T) {
tt := &NamedStructAN{
A: []string{},
B: nil,
}
var buf bytes.Buffer

err := msgp.Encode(&buf, tt)
if err != nil {
t.Fatal(err)
}
in := buf.Bytes()

for _, tnew := range []*NamedStructAN{{}, {A: []string{}}, {B: []string{}}} {
err = msgp.Decode(bytes.NewBuffer(in), tnew)
if err != nil {
t.Error(err)
}

if !reflect.DeepEqual(tt, tnew) {
t.Logf("in: %#v", tt)
t.Logf("out: %#v", tnew)
t.Fatal("objects not equal")
}
}

in, err = tt.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
for _, tanother := range []*NamedStructAN{{}, {A: []string{}}, {B: []string{}}} {
var left []byte
left, err = tanother.UnmarshalMsg(in)
if err != nil {
t.Error(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left", len(left))
}

if !reflect.DeepEqual(tt, tanother) {
t.Logf("in: %#v", tt)
t.Logf("out: %#v", tanother)
t.Fatal("objects not equal")
}
}
}
4 changes: 2 additions & 2 deletions _generated/gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ func Test1EncodeDecode(t *testing.T) {
}

if !tt.Equal(tnew) {
t.Logf("in: %v", tt)
t.Logf("out: %v", tnew)
t.Logf("in: %#v", tt)
t.Logf("out: %#v", tnew)
t.Fatal("objects not equal")
}

Expand Down
10 changes: 0 additions & 10 deletions _generated/issue94.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,6 @@ import (

//go:generate msgp

// Issue 94: shims were not propogated recursively,
// which caused shims that weren't at the top level
// to be silently ignored.
//
// The following line will generate an error after
// the code is generated if the generated code doesn't
// have the right identifier in it.

//go:generate ./search.sh $GOFILE timetostr

//msgp:shim time.Time as:string using:timetostr/strtotime
type T struct {
T time.Time
Expand Down
25 changes: 25 additions & 0 deletions _generated/issue94_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package _generated

import (
"bytes"
"os"
"testing"
)

// Issue 94: shims were not propogated recursively,
// which caused shims that weren't at the top level
// to be silently ignored.
//
// The following line will generate an error after
// the code is generated if the generated code doesn't
// have the right identifier in it.
func TestIssue94(t *testing.T) {
b, err := os.ReadFile("issue94_gen.go")
if err != nil {
t.Fatal(err)
}
const want = "timetostr"
if !bytes.Contains(b, []byte(want)) {
t.Errorf("generated code did not contain %q", want)
}
}
12 changes: 0 additions & 12 deletions _generated/search.sh

This file was deleted.

20 changes: 14 additions & 6 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,17 @@ func (d *decodeGen) structAsTuple(s *Struct) {
if !d.p.ok() {
return
}
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
d.p.print("\nif dc.IsNil() {")
d.p.print("\nerr = dc.ReadNil()")
d.p.wrapErrCheck(d.ctx.ArgsStr())
d.p.printf("\n%s = nil\n} else {", s.Fields[i].FieldElem.Varname())
}
SetIsAllowNil(fieldElem, anField)
d.ctx.PushString(s.Fields[i].FieldName)
next(d, s.Fields[i].FieldElem)
next(d, fieldElem)
d.ctx.Pop()
if anField {
d.p.printf("\n}") // close if statement
Expand All @@ -112,14 +114,16 @@ func (d *decodeGen) structAsMap(s *Struct) {
for i := range s.Fields {
d.ctx.PushString(s.Fields[i].FieldName)
d.p.printf("\ncase \"%s\":", s.Fields[i].FieldTag)
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
d.p.print("\nif dc.IsNil() {")
d.p.print("\nerr = dc.ReadNil()")
d.p.wrapErrCheck(d.ctx.ArgsStr())
d.p.printf("\n%s = nil\n} else {", s.Fields[i].FieldElem.Varname())
d.p.printf("\n%s = nil\n} else {", fieldElem.Varname())
}
next(d, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(d, fieldElem)
d.ctx.Pop()
if !d.p.ok() {
return
Expand Down Expand Up @@ -215,7 +219,11 @@ func (d *decodeGen) gSlice(s *Slice) {
sz := randIdent()
d.p.declare(sz, u32)
d.assignAndCheck(sz, arrayHeader)
d.p.resizeSlice(sz, s)
if s.isAllowNil {
d.p.resizeSliceNoNil(sz, s)
} else {
d.p.resizeSlice(sz, s)
}
d.p.rangeBlock(d.ctx, s.Index, s.Varname(), d, s.Els)
}

Expand Down
28 changes: 23 additions & 5 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,10 @@ func (a *Array) IfZeroExpr() string { return "" }
// Map is a map[string]Elem
type Map struct {
common
Keyidx string // key variable name
Validx string // value variable name
Value Elem // value element
Keyidx string // key variable name
Validx string // value variable name
Value Elem // value element
isAllowNil bool
}

func (m *Map) SetVarname(s string) {
Expand Down Expand Up @@ -302,10 +303,14 @@ func (m *Map) IfZeroExpr() string { return m.Varname() + " == nil" }
// AllowNil is true for maps.
func (m *Map) AllowNil() bool { return true }

// SetIsAllowNil sets whether the map is allowed to be nil.
func (m *Map) SetIsAllowNil(b bool) { m.isAllowNil = b }

type Slice struct {
common
Index string
Els Elem // The type of each element
Index string
isAllowNil bool
Els Elem // The type of each element
}

func (s *Slice) SetVarname(a string) {
Expand Down Expand Up @@ -346,6 +351,19 @@ func (s *Slice) IfZeroExpr() string { return s.Varname() + " == nil" }
// AllowNil is true for slices.
func (s *Slice) AllowNil() bool { return true }

// SetIsAllowNil sets whether the slice is allowed to be nil.
func (s *Slice) SetIsAllowNil(b bool) { s.isAllowNil = b }

// SetIsAllowNil will set whether the element is allowed to be nil.
func SetIsAllowNil(e Elem, b bool) {
type i interface {
SetIsAllowNil(b bool)
}
if x, ok := e.(i); ok {
x.SetIsAllowNil(b)
}
}

type Ptr struct {
common
Value Elem
Expand Down
11 changes: 7 additions & 4 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,14 @@ func (e *encodeGen) tuple(s *Struct) {
if !e.p.ok() {
return
}
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
e.p.printf("\nif %s { // allownil: if nil", s.Fields[i].FieldElem.IfZeroExpr())
e.p.printf("\nif %s { // allownil: if nil", fieldElem.IfZeroExpr())
e.p.printf("\nerr = en.WriteNil(); if err != nil { return; }")
e.p.printf("\n} else {")
}
SetIsAllowNil(fieldElem, anField)
e.ctx.PushString(s.Fields[i].FieldName)
next(e, s.Fields[i].FieldElem)
e.ctx.Pop()
Expand Down Expand Up @@ -189,13 +191,14 @@ func (e *encodeGen) structmap(s *Struct) {
e.p.printf("\n// write %q", s.Fields[i].FieldTag)
e.Fuse(data)
e.fuseHook()

anField := !oeField && s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := !oeField && s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
e.p.printf("\nif %s { // allownil: if nil", s.Fields[i].FieldElem.IfZeroExpr())
e.p.printf("\nerr = en.WriteNil(); if err != nil { return; }")
e.p.printf("\n} else {")
}
SetIsAllowNil(fieldElem, anField)

e.ctx.PushString(s.Fields[i].FieldName)
next(e, s.Fields[i].FieldElem)
Expand Down
17 changes: 10 additions & 7 deletions gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,16 @@ func (m *marshalGen) tuple(s *Struct) {
if !m.p.ok() {
return
}
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
m.p.printf("\nif %s { // allownil: if nil", s.Fields[i].FieldElem.IfZeroExpr())
m.p.printf("\nif %s { // allownil: if nil", fieldElem.IfZeroExpr())
m.p.printf("\no = msgp.AppendNil(o)")
m.p.printf("\n} else {")
}
m.ctx.PushString(s.Fields[i].FieldName)
next(m, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(m, fieldElem)
m.ctx.Pop()
if anField {
m.p.printf("\n}") // close if statement
Expand Down Expand Up @@ -186,15 +188,16 @@ func (m *marshalGen) mapstruct(s *Struct) {
m.Fuse(data)
m.fuseHook()

anField := !oeField && s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := !oeField && s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
m.p.printf("\nif %s { // allownil: if nil", s.Fields[i].FieldElem.IfZeroExpr())
m.p.printf("\nif %s { // allownil: if nil", fieldElem.IfZeroExpr())
m.p.printf("\no = msgp.AppendNil(o)")
m.p.printf("\n} else {")
}

m.ctx.PushString(s.Fields[i].FieldName)
next(m, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(m, fieldElem)
m.ctx.Pop()

if oeField || anField {
Expand Down
7 changes: 7 additions & 0 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ func (p *printer) resizeSlice(size string, s *Slice) {
p.printf("\nif cap(%[1]s) >= int(%[2]s) { %[1]s = (%[1]s)[:%[2]s] } else { %[1]s = make(%[3]s, %[2]s) }", s.Varname(), size, s.TypeName())
}

// resizeSliceNoNil will resize a slice and will not allow nil slices.
func (p *printer) resizeSliceNoNil(size string, s *Slice) {
p.printf("\nif %[1]s != nil && cap(%[1]s) >= int(%[2]s) {", s.Varname(), size)
p.printf("\n%[1]s = (%[1]s)[:%[2]s]", s.Varname(), size)
p.printf("\n} else { %[1]s = make(%[3]s, %[2]s) }", s.Varname(), size, s.TypeName())
}

func (p *printer) arrayCheck(want string, got string) {
p.printf("\nif %[1]s != %[2]s { err = msgp.ArrayError{Wanted: %[2]s, Got: %[1]s}; return }", got, want)
}
Expand Down
22 changes: 15 additions & 7 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,13 @@ func (u *unmarshalGen) tuple(s *Struct) {
return
}
u.ctx.PushString(s.Fields[i].FieldName)
anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", s.Fields[i].FieldElem.Varname())
u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", fieldElem.Varname())
}
next(u, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(u, fieldElem)
u.ctx.Pop()
if anField {
u.p.printf("\n}")
Expand All @@ -113,11 +115,13 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
u.p.printf("\ncase \"%s\":", s.Fields[i].FieldTag)
u.ctx.PushString(s.Fields[i].FieldName)

anField := s.Fields[i].HasTagPart("allownil") && s.Fields[i].FieldElem.AllowNil()
fieldElem := s.Fields[i].FieldElem
anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil()
if anField {
u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", s.Fields[i].FieldElem.Varname())
u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", fieldElem.Varname())
}
next(u, s.Fields[i].FieldElem)
SetIsAllowNil(fieldElem, anField)
next(u, fieldElem)
u.ctx.Pop()
if anField {
u.p.printf("\n}")
Expand Down Expand Up @@ -193,7 +197,11 @@ func (u *unmarshalGen) gSlice(s *Slice) {
sz := randIdent()
u.p.declare(sz, u32)
u.assignAndCheck(sz, arrayHeader)
u.p.resizeSlice(sz, s)
if s.isAllowNil {
u.p.resizeSliceNoNil(sz, s)
} else {
u.p.resizeSlice(sz, s)
}
u.p.rangeBlock(u.ctx, s.Index, s.Varname(), u, s.Els)
}

Expand Down
Loading