-
Notifications
You must be signed in to change notification settings - Fork 14
/
encoding.go
140 lines (120 loc) · 3.78 KB
/
encoding.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package gorums
import (
"fmt"
"github.com/relab/gorums/ordering"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
// ContentSubtype names the gRPC content-subtype under which gorums sends its
// messages; Codec.Name and Codec.String both report this value.
const (
	ContentSubtype = "gorums"
)
// gorumsMsgType tells Codec which side of an RPC a Message carries. The zero
// value is deliberately not a valid kind; gorumsUnmarshal rejects it.
type gorumsMsgType uint8

const (
	// requestType: the payload is decoded as the method's input message.
	requestType gorumsMsgType = 1
	// responseType: the payload is decoded as the method's output message.
	responseType gorumsMsgType = 2
)
// Message encapsulates a protobuf message and metadata.
//
// This struct should be used by generated code only.
//
// Codec encodes a *Message as the Metadata record followed by the Message
// record, each prefixed with its varint-encoded length.
type Message struct {
// Metadata travels ahead of Message on the wire. When decoding, its Method
// field is looked up in the global descriptor registry to decide which
// concrete type Message should be.
Metadata *ordering.Metadata
// Message is the payload. Codec's decoder replaces it with a freshly
// allocated value of the type resolved from Metadata.Method.
Message protoreflect.ProtoMessage
// msgType chooses whether the decoder uses the method's input type
// (requestType) or output type (responseType); other values are rejected.
msgType gorumsMsgType
}
// newMessage prepares an empty Message to be used as an unmarshaling target.
// A fresh Metadata is allocated so the decoder has somewhere to write it, and
// msgType is recorded so the decoder knows whether to build the payload as the
// method's request or response type. The Message field is left nil; the
// decoder fills it in.
func newMessage(msgType gorumsMsgType) *Message {
	m := new(Message)
	m.Metadata = new(ordering.Metadata)
	m.msgType = msgType
	return m
}
// Codec is the gRPC codec used by gorums.
//
// Plain protobuf messages are encoded unchanged. A *Message is encoded as a
// length-prefixed metadata record followed by a length-prefixed payload.
// NOTE(review): the zero Codec uses default proto options (AllowPartial off);
// use NewCodec to get the configuration gorums intends.
type Codec struct {
// marshaler holds the options used for every encode.
marshaler proto.MarshalOptions
// unmarshaler holds the options used for every decode.
unmarshaler proto.UnmarshalOptions
}
// NewCodec builds a Codec whose marshal and unmarshal options both enable
// AllowPartial, so messages with unset required fields are neither rejected
// on encode nor on decode.
func NewCodec() *Codec {
	c := new(Codec)
	c.marshaler.AllowPartial = true
	c.unmarshaler.AllowPartial = true
	return c
}
// Name reports the codec's name, which is always ContentSubtype.
func (Codec) Name() string { return ContentSubtype }
// String implements fmt.Stringer; it yields the same value as Name.
func (c Codec) String() string {
	return c.Name()
}
// Marshal encodes m. A *Message is encoded with the gorums framing (metadata
// record followed by payload record); any other protobuf message is encoded
// directly. Values of any other type produce an error.
func (c Codec) Marshal(m interface{}) (b []byte, err error) {
	if gorumsMsg, ok := m.(*Message); ok {
		return c.gorumsMarshal(gorumsMsg)
	}
	if protoMsg, ok := m.(protoreflect.ProtoMessage); ok {
		return c.marshaler.Marshal(protoMsg)
	}
	return nil, fmt.Errorf("gorums: cannot marshal message of type '%T'", m)
}
// gorumsMarshal marshals a metadata and a data message into a single byte slice.
//
// The output is two length-delimited records back to back:
//
//	varint(len(md))  md-bytes  varint(len(payload))  payload-bytes
//
// which is the layout gorumsUnmarshal consumes.
//
// Both encoded sizes are needed for the length prefixes anyway, so they are
// computed first and the output buffer is allocated once at its final
// capacity. This avoids regrowing the buffer on each append in this per-RPC
// hot path.
func (c Codec) gorumsMarshal(msg *Message) (b []byte, err error) {
	mdSize := c.marshaler.Size(msg.Metadata)
	msgSize := c.marshaler.Size(msg.Message)
	// SizeBytes accounts for the varint length prefix plus the payload itself.
	b = make([]byte, 0, protowire.SizeBytes(mdSize)+protowire.SizeBytes(msgSize))

	b = protowire.AppendVarint(b, uint64(mdSize))
	b, err = c.marshaler.MarshalAppend(b, msg.Metadata)
	if err != nil {
		return nil, err
	}

	b = protowire.AppendVarint(b, uint64(msgSize))
	b, err = c.marshaler.MarshalAppend(b, msg.Message)
	if err != nil {
		return nil, err
	}
	return b, nil
}
// Unmarshal decodes b into m. A *Message target is decoded using the gorums
// framing (metadata record, then payload record); any other protobuf message
// is decoded directly. Targets of any other type produce an error.
func (c Codec) Unmarshal(b []byte, m interface{}) (err error) {
	if gorumsMsg, ok := m.(*Message); ok {
		return c.gorumsUnmarshal(b, gorumsMsg)
	}
	if protoMsg, ok := m.(protoreflect.ProtoMessage); ok {
		return c.unmarshaler.Unmarshal(b, protoMsg)
	}
	return fmt.Errorf("gorums: cannot unmarshal message of type '%T'", m)
}
// gorumsUnmarshal extracts metadata and message data from b and places the result in msg.
//
// b must hold two length-delimited records, as written by gorumsMarshal: the
// encoded metadata followed by the encoded payload. The payload's concrete
// type is resolved by looking up msg.Metadata.Method in the global descriptor
// registry. msg.msgType then selects the method's input type (requestType) or
// output type (responseType), and msg.Message is replaced by a new instance of
// that type before decoding.
//
// b arrives from the network, so a malformed length prefix is reported as an
// error. It must not panic on a negative slice index, and it must not be
// silently decoded as an empty payload. Likewise, a method name that resolves
// to something other than a method descriptor is an error, not a panic.
func (c Codec) gorumsUnmarshal(b []byte, msg *Message) (err error) {
	// unmarshal metadata
	mdBuf, mdLen := protowire.ConsumeBytes(b)
	if mdLen < 0 {
		return fmt.Errorf("gorums: invalid metadata encoding: %w", protowire.ParseError(mdLen))
	}
	if err = c.unmarshaler.Unmarshal(mdBuf, msg.Metadata); err != nil {
		return err
	}

	// get method descriptor from registry
	methodName := protoreflect.FullName(msg.Metadata.Method)
	desc, err := protoregistry.GlobalFiles.FindDescriptorByName(methodName)
	if err != nil {
		return err
	}
	methodDesc, ok := desc.(protoreflect.MethodDescriptor)
	if !ok {
		return fmt.Errorf("gorums: %q is not a method descriptor", methodName)
	}

	// get message name depending on whether we are creating a request or response message
	var messageName protoreflect.FullName
	switch msg.msgType {
	case requestType:
		messageName = methodDesc.Input().FullName()
	case responseType:
		messageName = methodDesc.Output().FullName()
	default:
		return fmt.Errorf("gorums: unknown message type %d", msg.msgType)
	}

	// now get the message type from the types registry
	msgType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
	if err != nil {
		return err
	}
	msg.Message = msgType.New().Interface()

	// unmarshal message; mdLen is known to be non-negative here
	msgBuf, msgLen := protowire.ConsumeBytes(b[mdLen:])
	if msgLen < 0 {
		return fmt.Errorf("gorums: invalid message encoding: %w", protowire.ParseError(msgLen))
	}
	return c.unmarshaler.Unmarshal(msgBuf, msg.Message)
}