Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add protobuf rewrite rule overrides #144

Merged
merged 3 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 143 additions & 7 deletions proto/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,27 @@ func (f fieldset) index(i int) (int, int) {
// ParseRewriteTemplate constructs a Rewriter for a protobuf type using the
// given json template to describe the rewrite rules.
//
// The json template contains a representation of the
func ParseRewriteTemplate(typ Type, jsonTemplate []byte) (Rewriter, error) {
// The json template contains a representation of the message that is used as the
// source values to overwrite in the protobuf targeted by the resulting rewriter.
//
// The rules are an optional set of RewriterRules that can provide alternative
// Rewriters from the default used for the field type. These rules are given the
// json.RawMessage bytes from the template, and they are expected to create a
// Rewriter to be applied against the target protobuf.
func ParseRewriteTemplate(typ Type, jsonTemplate []byte, rules ...RewriterRules) (Rewriter, error) {
switch typ.Kind() {
case Struct:
return parseRewriteTemplateStruct(typ, 0, jsonTemplate)
return parseRewriteTemplateStruct(typ, 0, jsonTemplate, rules...)
default:
return nil, fmt.Errorf("cannot construct a rewrite template from a non-struct type %s", typ.Name())
}
}

func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage, rule any) (Rewriter, error) {
if rwer, ok := rule.(Rewriterer); ok {
return rwer.Rewriter(t, f, j)
}

switch t.Kind() {
case Bool:
return parseRewriteTemplateBool(t, f, j)
Expand Down Expand Up @@ -184,7 +194,11 @@ func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage) (Rewriter, e
case Map:
return parseRewriteTemplateMap(t, f, j)
case Struct:
return parseRewriteTemplateStruct(t, f, j)
sub, n, ok := [1]RewriterRules{}, 0, false
if sub[0], ok = rule.(RewriterRules); ok {
n = 1
}
return parseRewriteTemplateStruct(t, f, j, sub[:n]...)
default:
return nil, fmt.Errorf("cannot construct a rewriter from type %s", t.Name())
}
Expand Down Expand Up @@ -376,7 +390,7 @@ func parseRewriteTemplateMap(t Type, f FieldNumber, j json.RawMessage) (Rewriter
return MultiRewriter(rewriters...), nil
}

func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage, rules ...RewriterRules) (Rewriter, error) {
template := map[string]json.RawMessage{}

if err := json.Unmarshal(j, &template); err != nil {
Expand Down Expand Up @@ -408,10 +422,18 @@ func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage) (Rewri
fields = []json.RawMessage{v}
}

var rule any
for i := range rules {
if r, ok := rules[i][f.Name]; ok {
rule = r
break
}
}

rewriters = rewriters[:0]

for _, v := range fields {
rw, err := parseRewriteTemplate(f.Type, f.Number, v)
rw, err := parseRewriteTemplate(f.Type, f.Number, v, rule)
if err != nil {
return nil, fmt.Errorf("%s: %w", k, err)
}
Expand Down Expand Up @@ -462,3 +484,117 @@ func (f *embddedRewriter) Rewrite(out, in []byte) ([]byte, error) {
copy(out[prefix:], b[:tagAndLen])
return out, nil
}

// RewriterRules defines a set of rules for overriding the Rewriter used for any
// particular field. These maps may be nested for defining rules for struct members.
//
// For example:
//
// rules := proto.RewriterRules {
// "flags": proto.BitOr[uint64]{},
// "nested": proto.RewriterRules {
// "name": myCustomRewriter,
// },
// }
type RewriterRules map[string]any

// Rewriterer is the interface for producing a Rewriter for a given Type, FieldNumber
// and json.RawMessage. The JSON value is the JSON-encoded payload that should be
// decoded to produce the appropriate Rewriter. Implementations of the Rewriterer
// interface are added to the RewriterRules to specify the rules for performing
// custom rewrite logic.
type Rewriterer interface {
Rewriter(Type, FieldNumber, json.RawMessage) (Rewriter, error)
}

// BitOr implments the Rewriterer interface for providing a bitwise-or rewrite
// logic for integers rather than replacing them. Instances of this type are
// zero-size, carrying only the generic type for creating the appropriate
// Rewriter when requested.
//
// Adding these to a RewriterRules looks like:
//
// rules := proto.RewriterRules {
// "flags": proto.BitOr[uint64]{},
// }
//
// When used as a rule when rewriting from a template, the BitOr expects a JSON-
// encoded integer passed into the Rewriter method. This parsed integer is then
// used to perform a bitwise-or against the protobuf message that is being rewritten.
//
// The above example can then be used like:
//
// template := []byte(`{"flags": 8}`) // n |= 0b1000
// rw, err := proto.ParseRewriteTemplate(typ, template, rules)
type BitOr[T integer] struct{}

// integer is the contraint used by the BitOr Rewriterer and the bitOrRW Rewriter.
// Because these perform bitwise-or operations, the types must be integer-like.
type integer interface {
~int | ~int32 | ~int64 | ~uint | ~uint32 | ~uint64
}

// Rewriter implements the Rewriterer interface. The JSON value provided to this
// method comes from the template used for rewriting. The returned Rewriter will use
// this JSON-encoded integer to perform a bitwise-or against the protobuf message
// that is being rewritten.
func (BitOr[T]) Rewriter(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
var v T
err := json.Unmarshal(j, &v)
if err != nil {
return nil, err
}
return BitOrRewriter(t, f, v)
}

// BitOrRewriter creates a bitwise-or Rewriter for a given field type and number.
// The mask is the value or'ed with values in the target protobuf.
func BitOrRewriter[T integer](t Type, f FieldNumber, mask T) (Rewriter, error) {
switch t.Kind() {
case Int32, Int64, Sint32, Sint64, Uint32, Uint64, Fix32, Fix64, Sfix32, Sfix64:
default:
return nil, fmt.Errorf("cannot construct a rewriter from type %s", t.Name())
}
return bitOrRW[T]{mask: mask, t: t, f: f}, nil
}

// bitOrRW is the Rewriter returned by the BitOr Rewriter method.
type bitOrRW[T integer] struct {
mask T
t Type
f FieldNumber
}

// Rewrite implements the Rewriter interface performing a bitwise-or between the
// template value and the input value.
func (r bitOrRW[T]) Rewrite(out, in []byte) ([]byte, error) {
var v T
if err := Unmarshal(in, &v); err != nil {
return nil, err
}

v |= r.mask

switch r.t.Kind() {
case Int32:
return r.f.Int32(int32(v)).Rewrite(out, in)
case Int64:
return r.f.Int64(int64(v)).Rewrite(out, in)
case Sint32:
return r.f.Uint32(encodeZigZag32(int32(v))).Rewrite(out, in)
case Sint64:
return r.f.Uint64(encodeZigZag64(int64(v))).Rewrite(out, in)
case Uint32, Uint64:
return r.f.Uint64(uint64(v)).Rewrite(out, in)
case Fix32:
return r.f.Fixed32(uint32(v)).Rewrite(out, in)
case Fix64:
return r.f.Fixed64(uint64(v)).Rewrite(out, in)
case Sfix32:
return r.f.Fixed32(encodeZigZag32(int32(v))).Rewrite(out, in)
case Sfix64:
return r.f.Fixed64(encodeZigZag64(int64(v))).Rewrite(out, in)
}

panic("unreachable") // Kind is validated when creating instances
}
64 changes: 64 additions & 0 deletions proto/rewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,70 @@ func TestParseRewriteTemplate(t *testing.T) {
}
}

func TestParseRewriteRules(t *testing.T) {
type submessage struct {
Flags uint64 `protobuf:"varint,1,opt,name=flags,proto3"`
}

type message struct {
Flags uint64 `protobuf:"varint,2,opt,name=flags,proto3"`
Subfield *submessage `protobuf:"bytes,99,opt,name=subfield,proto3"`
}

original := &message{
Flags: 0b00000001,
Subfield: &submessage{
Flags: 0b00000010,
},
}

expected := &message{
Flags: 0b00000001 | 16,
Subfield: &submessage{
Flags: 0b00000010 | 32,
},
}

rules := RewriterRules{
"flags": BitOr[uint64]{},
"subfield": RewriterRules{
"flags": BitOr[uint64]{},
},
}

rw, err := ParseRewriteTemplate(TypeOf(reflect.TypeOf(original)), []byte(`{
"flags": 16,
"subfield": {
"flags": 32
}
}`), rules)

if err != nil {
t.Fatal(err)
}

b1, err := Marshal(original)
if err != nil {
t.Fatal(err)
}

b2, err := rw.Rewrite(nil, b1)
if err != nil {
t.Fatal(err)
}

found := &message{}
if err := Unmarshal(b2, &found); err != nil {
t.Fatal(err)
}

if !reflect.DeepEqual(expected, found) {
t.Error("messages mismatch after rewrite")
t.Logf("want:\n%+v", expected)
t.Logf("got:\n%+v", found)
}
}

func BenchmarkRewrite(b *testing.B) {
type message struct {
A int
Expand Down
Loading