diff --git a/_generated/def.go b/_generated/def.go index dd7d7020..062142f0 100644 --- a/_generated/def.go +++ b/_generated/def.go @@ -194,7 +194,7 @@ type Custom struct { Bts CustomBytes `msg:"bts"` Mp map[string]*Embedded `msg:"mp"` Enums []MyEnum `msg:"enums"` // test explicit enum shim - Some FileHandle `msg:file_handle` + Some FileHandle `msg:"file_handle"` } type Files []*os.File diff --git a/_generated/intercept_defs.go b/_generated/intercept_defs.go new file mode 100644 index 00000000..60f7835a --- /dev/null +++ b/_generated/intercept_defs.go @@ -0,0 +1,285 @@ +package _generated + +import ( + "fmt" + + "github.com/tinylib/msgp/msgp" +) + +//go:generate msgp + +//msgp:ignore testStructProvider +type testStructProvider struct { + Events []string +} + +var prv = &testStructProvider{} + +func resetStructProvider() { + prv = &testStructProvider{} +} + +func TestStructProvider() *testStructProvider { + return prv +} + +//msgp:intercept TestStructProvided using:TestStructProvider + +type TestStructProvided struct { + Foo string +} + +type TestUsesStructProvided struct { + Foo *TestStructProvided +} + +func (p *testStructProvider) DecodeMsg(dc *msgp.Reader) (t *TestStructProvided, err error) { + p.Events = append(p.Events, "decode") + t = new(TestStructProvided) + err = t.DecodeMsg(dc) + return +} + +func (p *testStructProvider) UnmarshalMsg(bts []byte) (t *TestStructProvided, o []byte, err error) { + t = new(TestStructProvided) + p.Events = append(p.Events, "unmarshal") + o, err = t.UnmarshalMsg(bts) + return +} + +func (p *testStructProvider) EncodeMsg(t *TestStructProvided, en *msgp.Writer) (err error) { + p.Events = append(p.Events, "encode") + return t.EncodeMsg(en) +} + +func (p *testStructProvider) MarshalMsg(t *TestStructProvided, b []byte) (o []byte, err error) { + p.Events = append(p.Events, "marshal") + return t.MarshalMsg(b) +} + +func (p *testStructProvider) Msgsize(t *TestStructProvided) (s int) { + p.Events = append(p.Events, "msgsize") + return t.Msgsize() +} + +//msgp:ignore testStringProvided +type testStringProvider struct { + Events []string +} + +var stringPrv = &testStringProvider{} + +func resetStringProvider() { + stringPrv = &testStringProvider{} +} + +func TestStringProvider() *testStringProvider { + return stringPrv +} + +//msgp:intercept TestStringProvided using:TestStringProvider +type TestStringProvided string + +type TestUsesStringProvided struct { + Foo TestStringProvided +} + +func (p *testStringProvider) DecodeMsg(dc *msgp.Reader) (t TestStringProvided, err error) { + p.Events = append(p.Events, "decode") + var s string + s, err = dc.ReadString() + if err != nil { + return + } + t = TestStringProvided(s) + return +} + +func (p *testStringProvider) UnmarshalMsg(bts []byte) (t TestStringProvided, o []byte, err error) { + p.Events = append(p.Events, "unmarshal") + var s string + s, o, err = msgp.ReadStringBytes(bts) + if err != nil { + return + } + t = TestStringProvided(s) + return +} + +func (p *testStringProvider) EncodeMsg(t TestStringProvided, en *msgp.Writer) (err error) { + p.Events = append(p.Events, "encode") + return en.WriteString(string(t)) +} + +func (p *testStringProvider) MarshalMsg(t TestStringProvided, b []byte) (o []byte, err error) { + p.Events = append(p.Events, "marshal") + o = msgp.AppendString(b, string(t)) + return +} + +func (p *testStringProvider) Msgsize(t TestStringProvided) (s int) { + return msgp.StringPrefixSize + len(t) +} + +//msgp:ignore testIntfStructProvider +type testIntfStructProvider struct { + Events []string +} + +var intfStructPrv = &testIntfStructProvider{} + +func resetIntfStructProvider() { + intfStructPrv = &testIntfStructProvider{} +} + +func TestIntfStructProvider() *testIntfStructProvider { + return intfStructPrv +} + +//msgp:intercept TestIntfStructProvided using:TestIntfStructProvider +type TestIntfStructProvided interface { + msgp.Decodable + msgp.Encodable + msgp.MarshalSizer + msgp.Unmarshaler +} + +type TestUsesIntfStructProvided struct { + Foo TestIntfStructProvided +} + +type TestUsesIntfStructProvidedSlice struct { + Foo []TestIntfStructProvided +} + +type TestUsesIntfStructProvidedMap struct { + Foo map[string]TestIntfStructProvided +} + +type TestIntfA struct { + Foo string +} + +type TestIntfB struct { + Bar string +} + +func (p *testIntfStructProvider) DecodeMsg(dc *msgp.Reader) (t TestIntfStructProvided, err error) { + p.Events = append(p.Events, "decode") + + if dc.IsNil() { + err = dc.ReadNil() + } else { + var s string + var sz uint32 + if sz, err = dc.ReadArrayHeader(); err != nil { + return + } + if sz != 2 { + err = fmt.Errorf("unexpected array length") + return + } + s, err = dc.ReadString() + if err != nil { + return + } + switch s { + case "a": + t = new(TestIntfA) + case "b": + t = new(TestIntfB) + default: + err = fmt.Errorf("unexpected type") + return + } + err = t.DecodeMsg(dc) + } + return +} + +func (p *testIntfStructProvider) UnmarshalMsg(bts []byte) (t TestIntfStructProvided, o []byte, err error) { + p.Events = append(p.Events, "unmarshal") + + o = bts + if msgp.IsNil(bts) { + o, err = msgp.ReadNilBytes(o) + } else { + var s string + var sz uint32 + if sz, o, err = msgp.ReadArrayHeaderBytes(o); err != nil { + return + } + if sz != 2 { + err = fmt.Errorf("unexpected array length") + return + } + s, o, err = msgp.ReadStringBytes(o) + if err != nil { + return + } + switch s { + case "a": + t = new(TestIntfA) + case "b": + t = new(TestIntfB) + default: + err = fmt.Errorf("unexpected type") + return + } + o, err = t.UnmarshalMsg(o) + } + return +} + +func (p *testIntfStructProvider) EncodeMsg(t TestIntfStructProvided, en *msgp.Writer) (err error) { + p.Events = append(p.Events, "encode") + if t == nil { + return en.WriteNil() + } else { + if err = en.WriteArrayHeader(2); err != nil { + return + } + var s string + switch t.(type) { + case *TestIntfA: + s = "a" + case *TestIntfB: + s = "b" + default: + err = fmt.Errorf("unexpected type %T", t) + } + if err = en.WriteString(s); err != nil { + return + } + return t.EncodeMsg(en) + } +} + +func (p *testIntfStructProvider) MarshalMsg(t TestIntfStructProvided, b []byte) (o []byte, err error) { + p.Events = append(p.Events, "marshal") + o = b + if t == nil { + o = msgp.AppendNil(o) + return + } else { + o = msgp.AppendArrayHeader(o, 2) + var s string + switch t.(type) { + case *TestIntfA: + s = "a" + case *TestIntfB: + s = "b" + default: + err = fmt.Errorf("unexpected type %T", t) + } + o = msgp.AppendString(o, s) + return t.MarshalMsg(o) + } +} + +func (p *testIntfStructProvider) Msgsize(t TestIntfStructProvided) (s int) { + if t == nil { + return msgp.NilSize + } + return t.Msgsize() +} diff --git a/_generated/intercept_test.go b/_generated/intercept_test.go new file mode 100644 index 00000000..e597927c --- /dev/null +++ b/_generated/intercept_test.go @@ -0,0 +1,271 @@ +package _generated + +import ( + "bytes" + "reflect" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestInterceptEncodeDecodeStruct(t *testing.T) { + resetStructProvider() + + in := TestUsesStructProvided{Foo: &TestStructProvided{Foo: "hi"}} + + var buf bytes.Buffer + wrt := msgp.NewWriter(&buf) + if err := in.EncodeMsg(wrt); err != nil { + t.Errorf("%v", err) + } + wrt.Flush() + + var out TestUsesStructProvided + rdr := msgp.NewReader(&buf) + if err := (&out).DecodeMsg(rdr); err != nil { + t.Errorf("%v", err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatalf("provided encode decode failed") + } + + if !reflect.DeepEqual([]string{"encode", "decode"}, TestStructProvider().Events) { + t.Fatalf("unexpected events: %v", TestStructProvider().Events) + } +} + +func TestInterceptMarshalUnmarshalStruct(t *testing.T) { + resetStructProvider() + + in := TestUsesStructProvided{Foo: &TestStructProvided{Foo: "hi"}} + + bts, err := in.MarshalMsg(nil) + if err != nil { + t.Fatalf("%v", err) + } + + var out TestUsesStructProvided + if _, err := (&out).UnmarshalMsg(bts); err != nil { + t.Fatalf("%v", err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatalf("provided unmarshal failed") + } + + if !reflect.DeepEqual([]string{"msgsize", "marshal", "unmarshal"}, TestStructProvider().Events) { + t.Fatalf("unexpected events: %v", TestStructProvider().Events) + } +} + +func TestInterceptEncodeDecodeString(t *testing.T) { + resetStringProvider() + + in := TestUsesStringProvided{Foo: TestStringProvided("hi")} + + var buf bytes.Buffer + wrt := msgp.NewWriter(&buf) + if err := in.EncodeMsg(wrt); err != nil { + t.Errorf("%v", err) + } + wrt.Flush() + + var out TestUsesStringProvided + rdr := msgp.NewReader(&buf) + if err := (&out).DecodeMsg(rdr); err != nil { + t.Errorf("%v", err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatalf("provided encode decode failed") + } + + if !reflect.DeepEqual([]string{"encode", "decode"}, TestStringProvider().Events) { + t.Fatalf("unexpected events: %v", TestStringProvider().Events) + } +} + +func TestInterceptMarshalUnmarshalString(t *testing.T) { + resetStringProvider() + + in := TestUsesStringProvided{Foo: TestStringProvided("hi")} + + bts, err := in.MarshalMsg(nil) + if err != nil { + t.Fatalf("%v", err) + } + + var out TestUsesStringProvided + if _, err := (&out).UnmarshalMsg(bts); err != nil { + t.Fatalf("%v", err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatalf("provided unmarshal failed") + } + + if !reflect.DeepEqual([]string{"marshal", "unmarshal"}, TestStringProvider().Events) { + t.Fatalf("unexpected events, found: %v", TestStringProvider().Events) + } +} + +func TestInterceptInterfaceEncodeDecode(t *testing.T) { + cases := []TestUsesIntfStructProvided{ + {Foo: &TestIntfA{Foo: "hello"}}, + {Foo: &TestIntfB{Bar: "world"}}, + {Foo: nil}, + } + + for _, in := range cases { + resetIntfStructProvider() + + var buf bytes.Buffer + wrt := msgp.NewWriter(&buf) + if err := in.EncodeMsg(wrt); err != nil { + t.Errorf("%v", err) + } + wrt.Flush() + + var out TestUsesIntfStructProvided + rdr := msgp.NewReader(&buf) + if err := (&out).DecodeMsg(rdr); err != nil { + t.Errorf("%v", err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatalf("provided encode decode failed") + } + + if !reflect.DeepEqual([]string{"encode", "decode"}, TestIntfStructProvider().Events) { + t.Fatalf("unexpected events: %v", TestIntfStructProvider().Events) + } + } +} + +func TestInterceptInterfaceMarshalUnmarshal(t *testing.T) { + cases := []TestUsesIntfStructProvided{ + {Foo: &TestIntfA{Foo: "hello"}}, + {Foo: &TestIntfB{Bar: "world"}}, + {Foo: nil}, + } + + for _, in := range cases { + resetIntfStructProvider() + + bts, err := in.MarshalMsg(nil) + if err != nil { + t.Fatalf("%v", err) + } + + var out TestUsesIntfStructProvided + if _, err := (&out).UnmarshalMsg(bts); err != nil { + t.Fatalf("%v", err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatalf("provided marshal/unmarshal failed") + } + + if !reflect.DeepEqual([]string{"marshal", "unmarshal"}, TestIntfStructProvider().Events) { + t.Fatalf("unexpected events: %v", TestIntfStructProvider().Events) + } + } +} + +func TestInterceptInterfaceSliceMarshalUnmarshal(t *testing.T) { + cases := []TestUsesIntfStructProvidedSlice{ + {Foo: []TestIntfStructProvided{ + &TestIntfA{Foo: "hello"}, + &TestIntfB{Bar: "world"}, + }}, + + // FIXME: empty slice unmarshals as nil, is this msgp? + // {Foo: []TestIntfStructProvided{}}, + } + + for _, in := range cases { + resetIntfStructProvider() + + bts, err := in.MarshalMsg(nil) + if err != nil { + t.Fatalf("%v", err) + } + + var out TestUsesIntfStructProvidedSlice + if _, err := (&out).UnmarshalMsg(bts); err != nil { + t.Fatalf("%v", err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatalf("provided marshal/unmarshal failed") + } + + if !reflect.DeepEqual([]string{"marshal", "marshal", "unmarshal", "unmarshal"}, TestIntfStructProvider().Events) { + t.Fatalf("unexpected events: %v", TestIntfStructProvider().Events) + } + } +} + +func TestInterceptInterfaceMapMarshalUnmarshal(t *testing.T) { + cases := []TestUsesIntfStructProvidedMap{ + {Foo: map[string]TestIntfStructProvided{ + "a": &TestIntfA{Foo: "hello"}, + "b": &TestIntfB{Bar: "world"}, + }}, + + // FIXME: empty slice unmarshals as nil, is this msgp? + // {Foo: []TestIntfStructProvided{}}, + } + + for _, in := range cases { + resetIntfStructProvider() + + bts, err := in.MarshalMsg(nil) + if err != nil { + t.Fatalf("%v", err) + } + + var out TestUsesIntfStructProvidedMap + if _, err := (&out).UnmarshalMsg(bts); err != nil { + t.Fatalf("%v", err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatalf("provided marshal/unmarshal failed") + } + + if !reflect.DeepEqual([]string{"marshal", "marshal", "unmarshal", "unmarshal"}, TestIntfStructProvider().Events) { + t.Fatalf("unexpected events: %v", TestIntfStructProvider().Events) + } + } +} + +func TestInterceptInterfaceUnmarshalAsJSON(t *testing.T) { + cases := []struct { + in TestUsesIntfStructProvided + out string + }{ + {in: TestUsesIntfStructProvided{Foo: &TestIntfA{Foo: "hello"}}, out: `{"Foo":["a",{"Foo":"hello"}]}`}, + {in: TestUsesIntfStructProvided{Foo: &TestIntfB{Bar: "world"}}, out: `{"Foo":["b",{"Bar":"world"}]}`}, + {in: TestUsesIntfStructProvided{Foo: nil}, out: `{"Foo":null}`}, + } + + for idx, tcase := range cases { + resetIntfStructProvider() + + bts, err := tcase.in.MarshalMsg(nil) + if err != nil { + t.Fatalf("%v", err) + } + + var buf bytes.Buffer + if _, err := msgp.UnmarshalAsJSON(&buf, bts); err != nil { + t.Fatalf("%v", err) + } + + if tcase.out != buf.String() { + t.Fatalf("%d: unexpected JSON `%s`", idx, buf.String()) + } + } +} diff --git a/gen/decode.go b/gen/decode.go index 3ba88fad..5d04178e 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -133,7 +133,11 @@ func (d *decodeGen) gBase(b *BaseElem) { d.p.printf("\n%s, err = dc.ReadBytes(%s)", vname, vname) } case IDENT: - d.p.printf("\nerr = %s.DecodeMsg(dc)", vname) + if b.Provider() != "" { + d.p.printf("\n%s, err = %s().DecodeMsg(dc)", vname, b.Provider()) + } else { + d.p.printf("\nerr = %s.DecodeMsg(dc)", vname) + } case Ext: d.p.printf("\nerr = dc.ReadExtension(%s)", vname) default: diff --git a/gen/elem.go b/gen/elem.go index 5da62b7d..e6843279 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -23,6 +23,12 @@ func randIdent() string { return "z" + string(bts) } +// interface checks +var ( + _ Intercepted = &BaseElem{} + _ Intercepted = &Struct{} +) + // This code defines the type declaration tree. // // Consider the following: @@ -136,6 +142,7 @@ var builtins = map[string]struct{}{ type common struct{ vname, alias string } func (c *common) SetVarname(s string) { c.vname = s } +func (c *common) Printable() bool { return true } func (c *common) Varname() string { return c.vname } func (c *common) Alias(typ string) { c.alias = typ } func (c *common) hidden() {} @@ -180,6 +187,8 @@ type Elem interface { // or equal to 1.) Complexity() int + Printable() bool + hidden() } @@ -353,18 +362,50 @@ func (s *Ptr) Copy() Elem { func (s *Ptr) Complexity() int { return 1 + s.Value.Complexity() } func (s *Ptr) Needsinit() bool { + if IsIntercepted(s.Value) { + return false + } if be, ok := s.Value.(*BaseElem); ok && be.needsref { return false } return true } +type IntfElem struct { + common + provider string // struct is intercepted +} + +func (s *IntfElem) Printable() bool { return false } + +func (s *IntfElem) Provider() string { return s.provider } + +func (s *IntfElem) SetProvider(p string) { s.provider = p } + +func (s *IntfElem) TypeName() string { + return s.common.alias +} + +func (s *IntfElem) Copy() Elem { + z := *s + return &z +} + +func (s *IntfElem) Complexity() int { + return 1 +} + type Struct struct { common - Fields []StructField // field list - AsTuple bool // write as an array instead of a map + Fields []StructField // field list + AsTuple bool // write as an array instead of a map + provider string // struct is intercepted } +func (s *Struct) Provider() string { return s.provider } + +func (s *Struct) SetProvider(p string) { s.provider = p } + func (s *Struct) TypeName() string { if s.common.alias != "" { return s.common.alias @@ -424,12 +465,17 @@ type BaseElem struct { ShimFromBase string // shim from base type, or empty Value Primitive // Type of element Convert bool // should we do an explicit conversion? + provider string // base elem is intercepted mustinline bool // must inline; not printable needsref bool // needs reference for shim } func (s *BaseElem) Printable() bool { return !s.mustinline } +func (s *BaseElem) Provider() string { return s.provider } + +func (s *BaseElem) SetProvider(p string) { s.provider = p } + func (s *BaseElem) Alias(typ string) { s.common.Alias(typ) if s.Value != IDENT { @@ -606,6 +652,18 @@ func writeStructFields(s []StructField, name string) { } } +type Intercepted interface { + Provider() string + SetProvider(s string) +} + +func IsIntercepted(e Elem) bool { + if p, ok := e.(Intercepted); ok { + return p.Provider() != "" + } + return false +} + // coerceArraySize ensures we can compare constant array lengths. // // msgpack array headers are 32 bit unsigned, which is reflected in the diff --git a/gen/encode.go b/gen/encode.go index f2c9ff3e..37e16cf2 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -187,7 +187,11 @@ func (e *encodeGen) gBase(b *BaseElem) { } if b.Value == IDENT { // unknown identity - e.p.printf("\nerr = %s.EncodeMsg(en)", vname) + if b.Provider() != "" { + e.p.printf("\nerr = %s().EncodeMsg(%s, en)", b.Provider(), vname) + } else { + e.p.printf("\nerr = %s.EncodeMsg(en)", vname) + } e.p.print(errcheck) } else { // typical case e.writeAndCheck(b.BaseName(), literalFmt, vname) diff --git a/gen/marshal.go b/gen/marshal.go index e92b44a3..7e595333 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -195,7 +195,11 @@ func (m *marshalGen) gBase(b *BaseElem) { switch b.Value { case IDENT: echeck = true - m.p.printf("\no, err = %s.MarshalMsg(o)", vname) + if b.Provider() != "" { + m.p.printf("\no, err = %s().MarshalMsg(%s, o)", b.Provider(), vname) + } else { + m.p.printf("\no, err = %s.MarshalMsg(o)", vname) + } case Intf, Ext: echeck = true m.p.printf("\no, err = msgp.Append%s(o, %s)", b.BaseName(), vname) diff --git a/gen/size.go b/gen/size.go index fd9cc421..7dd9cb3d 100644 --- a/gen/size.go +++ b/gen/size.go @@ -192,7 +192,7 @@ func (s *sizeGen) gBase(b *BaseElem) { // ensure we don't get "unused variable" warnings from outer slice iterations s.p.printf("\n_ = %s", b.Varname()) - s.p.printf("\ns += %s", basesizeExpr(b.Value, vname, b.BaseName())) + s.p.printf("\ns += %s", basesizeExpr(b, b.Value, vname, b.BaseName())) s.state = expr } else { @@ -200,7 +200,7 @@ func (s *sizeGen) gBase(b *BaseElem) { if b.Convert { vname = tobaseConvert(b) } - s.addConstant(basesizeExpr(b.Value, vname, b.BaseName())) + s.addConstant(basesizeExpr(b, b.Value, vname, b.BaseName())) } } @@ -268,14 +268,18 @@ func fixedsizeExpr(e Elem) (string, bool) { } // print size expression of a variable name -func basesizeExpr(value Primitive, vname, basename string) string { +func basesizeExpr(e Elem, value Primitive, vname, basename string) string { switch value { case Ext: return "msgp.ExtensionPrefixSize + " + stripRef(vname) + ".Len()" case Intf: return "msgp.GuessSize(" + vname + ")" case IDENT: - return vname + ".Msgsize()" + if p, ok := e.(Intercepted); ok && p.Provider() != "" { + return fmt.Sprintf("%s().Msgsize(%s)", p.Provider(), vname) + } else { + return vname + ".Msgsize()" + } case Bytes: return "msgp.BytesPrefixSize + len(" + vname + ")" case String: diff --git a/gen/unmarshal.go b/gen/unmarshal.go index 283beaff..b4a8c83d 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -128,7 +128,11 @@ func (u *unmarshalGen) gBase(b *BaseElem) { case Ext: u.p.printf("\nbts, err = msgp.ReadExtensionBytes(bts, %s)", lowered) case IDENT: - u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered) + if b.Provider() != "" { + u.p.printf("\n%s, bts, err = %s().UnmarshalMsg(bts)", lowered, b.Provider()) + } else { + u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered) + } default: u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", refname, b.BaseName()) } diff --git a/msgp/edit.go b/msgp/edit.go index 41f92986..7f17f4ed 100644 --- a/msgp/edit.go +++ b/msgp/edit.go @@ -1,8 +1,6 @@ package msgp -import ( - "math" -) +import "math" // Locate returns a []byte pointing to the field // in a messagepack map with the provided key. (The returned []byte diff --git a/parse/directives.go b/parse/directives.go index 73e441ef..ea0c4580 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -21,9 +21,10 @@ type passDirective func(gen.Method, []string, *gen.Printer) error // to add a directive, define a func([]string, *FileSet) error // and then add it to this list. var directives = map[string]directive{ - "shim": applyShim, - "ignore": ignore, - "tuple": astuple, + "shim": applyShim, + "ignore": ignore, + "tuple": astuple, + "intercept": applyIntercept, } var passDirectives = map[string]passDirective{ @@ -91,6 +92,7 @@ func applyShim(text []string, f *FileSet) error { infof("%s -> %s\n", name, be.Value.String()) f.findShim(name, be) + f.Identities[name] = be return nil } @@ -128,3 +130,26 @@ func astuple(text []string, f *FileSet) error { } return nil } + +//msgp:intercept {Type} using:{Provider} +func applyIntercept(text []string, f *FileSet) error { + if len(text) != 3 { + return fmt.Errorf("invalid syntax. expected 'msgp:intercept Type using:ProviderFunc'") + } + + t := strings.TrimSpace(text[1]) + using := strings.TrimPrefix(strings.TrimSpace(text[2]), "using:") + + be := gen.Ident(t) + be.SetProvider(using) + f.findShim(t, be) + if ident, ok := f.Identities[t]; ok { + if p, ok := ident.(gen.Intercepted); ok { + p.SetProvider(using) + } else { + return fmt.Errorf("attempted to intercept unexpeted type %T", ident) + } + } + + return nil +} diff --git a/parse/getast.go b/parse/getast.go index b466c287..378de834 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -256,11 +256,14 @@ func (f *FileSet) PrintTo(p *gen.Printer) error { for _, name := range names { el := f.Identities[name] el.SetVarname("z") - pushstate(el.TypeName()) - err := p.Print(el) - popstate() - if err != nil { - return err + + if el.Printable() { + pushstate(el.TypeName()) + err := p.Print(el) + popstate() + if err != nil { + return err + } } } return nil @@ -292,6 +295,7 @@ func (fs *FileSet) getTypeSpecs(f *ast.File) { *ast.ArrayType, *ast.StarExpr, *ast.MapType, + *ast.InterfaceType, *ast.Ident: fs.Specs[ts.Name.Name] = ts.Type @@ -535,8 +539,9 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { // support `interface{}` if len(e.Methods.List) == 0 { return &gen.BaseElem{Value: gen.Intf} + } else { + return &gen.IntfElem{} } - return nil default: // other types not supported return nil diff --git a/parse/inline.go b/parse/inline.go index 5dba4e56..a92d9633 100644 --- a/parse/inline.go +++ b/parse/inline.go @@ -1,8 +1,6 @@ package parse -import ( - "github.com/tinylib/msgp/gen" -) +import "github.com/tinylib/msgp/gen" // This file defines when and how we // propagate type information from @@ -49,8 +47,6 @@ func (f *FileSet) findShim(id string, be *gen.BaseElem) { } popstate() } - // we'll need this at the top level as well - f.Identities[id] = be } func (f *FileSet) nextShim(ref *gen.Elem, id string, be *gen.BaseElem) { @@ -110,7 +106,8 @@ func (f *FileSet) nextInline(ref *gen.Elem, root string) { // a type into itself typ := el.TypeName() if el.Value == gen.IDENT && typ != root { - if node, ok := f.Identities[typ]; ok && node.Complexity() < maxComplex { + isIntercepted := gen.IsIntercepted(el) + if node, ok := f.Identities[typ]; ok && node.Complexity() < maxComplex && !isIntercepted { infof("inlining %s\n", typ) // This should never happen; it will cause @@ -121,7 +118,8 @@ func (f *FileSet) nextInline(ref *gen.Elem, root string) { *ref = node.Copy() f.nextInline(ref, node.TypeName()) - } else if !ok && !el.Resolved() { + + } else if !ok && !el.Resolved() && !isIntercepted { // this is the point at which we're sure that // we've got a type that isn't a primitive, // a library builtin, or a processed type