-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add rapidproto generator (#14849)
- Loading branch information
1 parent
deeb4bd
commit 72db75b
Showing
7 changed files
with
259 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
package rapidproto | ||
|
||
import ( | ||
"fmt" | ||
|
||
"google.golang.org/protobuf/proto" | ||
"google.golang.org/protobuf/reflect/protoreflect" | ||
"google.golang.org/protobuf/reflect/protoregistry" | ||
"gotest.tools/v3/assert" | ||
"pgregory.net/rapid" | ||
) | ||
|
||
func MessageGenerator[T proto.Message](x T, options GeneratorOptions) *rapid.Generator[T] { | ||
msgType := x.ProtoReflect().Type() | ||
return rapid.Custom(func(t *rapid.T) T { | ||
msg := msgType.New() | ||
|
||
options.setFields(t, msg, 0) | ||
|
||
return msg.Interface().(T) | ||
}) | ||
} | ||
|
||
type GeneratorOptions struct { | ||
AnyTypeURLs []string | ||
Resolver protoregistry.MessageTypeResolver | ||
} | ||
|
||
const depthLimit = 10 | ||
|
||
func (opts GeneratorOptions) setFields(t *rapid.T, msg protoreflect.Message, depth int) bool { | ||
// to avoid stack overflow we limit the depth of nested messages | ||
if depth > depthLimit { | ||
return false | ||
} | ||
|
||
descriptor := msg.Descriptor() | ||
fullName := descriptor.FullName() | ||
switch fullName { | ||
case timestampFullName: | ||
opts.genTimestamp(t, msg) | ||
return true | ||
case durationFullName: | ||
opts.genDuration(t, msg) | ||
return true | ||
case anyFullName: | ||
return opts.genAny(t, msg, depth) | ||
case fieldMaskFullName: | ||
opts.genFieldMask(t, msg) | ||
return true | ||
default: | ||
fields := descriptor.Fields() | ||
n := fields.Len() | ||
for i := 0; i < n; i++ { | ||
field := fields.Get(i) | ||
if !rapid.Bool().Draw(t, fmt.Sprintf("gen-%s", field.Name())) { | ||
continue | ||
} | ||
|
||
opts.setFieldValue(t, msg, field, depth) | ||
} | ||
return true | ||
} | ||
} | ||
|
||
const ( | ||
timestampFullName = "google.protobuf.Timestamp" | ||
durationFullName = "google.protobuf.Duration" | ||
anyFullName = "google.protobuf.Any" | ||
fieldMaskFullName = "google.protobuf.FieldMask" | ||
) | ||
|
||
func (opts GeneratorOptions) setFieldValue(t *rapid.T, msg protoreflect.Message, field protoreflect.FieldDescriptor, depth int) { | ||
name := string(field.Name()) | ||
kind := field.Kind() | ||
|
||
switch { | ||
case field.IsList(): | ||
list := msg.Mutable(field).List() | ||
n := rapid.IntRange(0, 10).Draw(t, fmt.Sprintf("%sN", name)) | ||
for i := 0; i < n; i++ { | ||
if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind { | ||
if !opts.setFields(t, list.AppendMutable().Message(), depth+1) { | ||
list.Truncate(i) | ||
} | ||
} else { | ||
list.Append(opts.genScalarFieldValue(t, field, fmt.Sprintf("%s%d", name, i))) | ||
} | ||
} | ||
case field.IsMap(): | ||
m := msg.Mutable(field).Map() | ||
n := rapid.IntRange(0, 10).Draw(t, fmt.Sprintf("%sN", name)) | ||
for i := 0; i < n; i++ { | ||
keyField := field.MapKey() | ||
valueField := field.MapValue() | ||
valueKind := valueField.Kind() | ||
key := opts.genScalarFieldValue(t, keyField, fmt.Sprintf("%s%d-key", name, i)) | ||
if valueKind == protoreflect.MessageKind || valueKind == protoreflect.GroupKind { | ||
if !opts.setFields(t, m.Mutable(key.MapKey()).Message(), depth+1) { | ||
m.Clear(key.MapKey()) | ||
} | ||
} else { | ||
value := opts.genScalarFieldValue(t, valueField, fmt.Sprintf("%s%d-key", name, i)) | ||
m.Set(key.MapKey(), value) | ||
} | ||
} | ||
default: | ||
if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind { | ||
if !opts.setFields(t, msg.Mutable(field).Message(), depth+1) { | ||
msg.Clear(field) | ||
} | ||
} else { | ||
msg.Set(field, opts.genScalarFieldValue(t, field, name)) | ||
} | ||
} | ||
} | ||
|
||
func (opts GeneratorOptions) genScalarFieldValue(t *rapid.T, field protoreflect.FieldDescriptor, name string) protoreflect.Value { | ||
switch field.Kind() { | ||
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: | ||
return protoreflect.ValueOfInt32(rapid.Int32().Draw(t, name)) | ||
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: | ||
return protoreflect.ValueOfUint32(rapid.Uint32().Draw(t, name)) | ||
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: | ||
return protoreflect.ValueOfInt64(rapid.Int64().Draw(t, name)) | ||
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: | ||
return protoreflect.ValueOfUint64(rapid.Uint64().Draw(t, name)) | ||
case protoreflect.BoolKind: | ||
return protoreflect.ValueOfBool(rapid.Bool().Draw(t, name)) | ||
case protoreflect.BytesKind: | ||
return protoreflect.ValueOfBytes(rapid.SliceOf(rapid.Byte()).Draw(t, name)) | ||
case protoreflect.FloatKind: | ||
return protoreflect.ValueOfFloat32(rapid.Float32().Draw(t, name)) | ||
case protoreflect.DoubleKind: | ||
return protoreflect.ValueOfFloat64(rapid.Float64().Draw(t, name)) | ||
case protoreflect.EnumKind: | ||
enumValues := field.Enum().Values() | ||
val := rapid.Int32Range(0, int32(enumValues.Len()-1)).Draw(t, name) | ||
return protoreflect.ValueOfEnum(protoreflect.EnumNumber(val)) | ||
case protoreflect.StringKind: | ||
return protoreflect.ValueOfString(rapid.String().Draw(t, name)) | ||
default: | ||
t.Fatalf("unexpected %v", field) | ||
return protoreflect.Value{} | ||
} | ||
} | ||
|
||
const ( | ||
secondsName = "seconds" | ||
nanosName = "nanos" | ||
) | ||
|
||
func (opts GeneratorOptions) genTimestamp(t *rapid.T, msg protoreflect.Message) { | ||
seconds := rapid.Int64Range(-9999999999, 9999999999).Draw(t, "seconds") | ||
nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos") | ||
setSecondsNanosFields(t, msg, seconds, nanos) | ||
} | ||
|
||
func (opts GeneratorOptions) genDuration(t *rapid.T, msg protoreflect.Message) { | ||
seconds := rapid.Int64Range(0, 315576000000).Draw(t, "seconds") | ||
nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos") | ||
setSecondsNanosFields(t, msg, seconds, nanos) | ||
} | ||
|
||
func setSecondsNanosFields(t *rapid.T, message protoreflect.Message, seconds int64, nanos int32) { | ||
fields := message.Descriptor().Fields() | ||
|
||
secondsField := fields.ByName(secondsName) | ||
assert.Assert(t, secondsField != nil) | ||
message.Set(secondsField, protoreflect.ValueOfInt64(seconds)) | ||
|
||
nanosField := fields.ByName(nanosName) | ||
assert.Assert(t, nanosField != nil) | ||
message.Set(nanosField, protoreflect.ValueOfInt32(nanos)) | ||
} | ||
|
||
const ( | ||
typeURLName = "type_url" | ||
valueName = "value" | ||
) | ||
|
||
func (opts GeneratorOptions) genAny(t *rapid.T, msg protoreflect.Message, depth int) bool { | ||
if len(opts.AnyTypeURLs) == 0 { | ||
return false | ||
} | ||
|
||
fields := msg.Descriptor().Fields() | ||
|
||
typeURL := rapid.SampledFrom(opts.AnyTypeURLs).Draw(t, "type_url") | ||
typ, err := opts.Resolver.FindMessageByURL(typeURL) | ||
assert.NilError(t, err) | ||
|
||
typeURLField := fields.ByName(typeURLName) | ||
assert.Assert(t, typeURLField != nil) | ||
msg.Set(typeURLField, protoreflect.ValueOfString(typeURL)) | ||
|
||
valueMsg := typ.New() | ||
opts.setFields(t, valueMsg, depth+1) | ||
valueBz, err := proto.Marshal(valueMsg.Interface()) | ||
assert.NilError(t, err) | ||
|
||
valueField := fields.ByName(valueName) | ||
assert.Assert(t, valueField != nil) | ||
msg.Set(valueField, protoreflect.ValueOfBytes(valueBz)) | ||
|
||
return true | ||
} | ||
|
||
const ( | ||
pathsName = "paths" | ||
) | ||
|
||
func (opts GeneratorOptions) genFieldMask(t *rapid.T, msg protoreflect.Message) { | ||
paths := rapid.SliceOfN(rapid.StringMatching("[a-z]+([.][a-z]+){0,2}"), 1, 5).Draw(t, "paths") | ||
pathsField := msg.Descriptor().Fields().ByName(pathsName) | ||
assert.Assert(t, pathsField != nil) | ||
pathsList := msg.NewField(pathsField).List() | ||
for _, path := range paths { | ||
pathsList.Append(protoreflect.ValueOfString(path)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package rapidproto_test | ||
|
||
import ( | ||
"fmt" | ||
"testing" | ||
|
||
"google.golang.org/protobuf/encoding/protojson" | ||
"google.golang.org/protobuf/proto" | ||
"gotest.tools/v3/assert" | ||
"gotest.tools/v3/golden" | ||
"pgregory.net/rapid" | ||
|
||
"github.com/cosmos/cosmos-proto/testpb" | ||
|
||
"github.com/cosmos/cosmos-sdk/testutil/rapidproto" | ||
) | ||
|
||
// TestRegression checks that the generator still produces the same output | ||
// for the same random seeds, assuming that this data has been hand expected | ||
// to generally look good. | ||
func TestRegression(t *testing.T) { | ||
gen := rapidproto.MessageGenerator(&testpb.A{}, rapidproto.GeneratorOptions{}) | ||
for i := 0; i < 5; i++ { | ||
testRegressionSeed(t, i, gen) | ||
} | ||
} | ||
|
||
func testRegressionSeed[X proto.Message](t *testing.T, seed int, generator *rapid.Generator[X]) { | ||
x := generator.Example(seed) | ||
bz, err := protojson.Marshal(x) | ||
assert.NilError(t, err) | ||
golden.Assert(t, string(bz), fmt.Sprintf("seed%d.json", seed)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"enum":"Two", "someBoolean":true, "INT32":6, "SINT32":-53, "INT64":"-261", "SFIXED32":3, "FIXED32":65302, "FIXED64":"45044", "STRING":"~Âaႃ#", "MESSAGE":{"x":"ʰ="}, "MAP":{"":{"x":"௹"}, "%󠇯º$&.":{"x":"-"}, "=A":{}, "AA|𞀠":{"x":"a\u0000ๆ"}}, "LIST":[{}], "ONEOFSTRING":"", "imported":{}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"UINT32":177, "INT64":"-139958413", "SFIXED32":41418, "FIXED32":25381940, "FLOAT":-8.336453e+31, "SFIXED64":"-2503553836720", "DOUBLE":-0.03171187036377887, "STRING":"?˄~ע", "MESSAGE":{"x":"dDž#"}, "MAP":{"Ⱥa<":{"x":"+["}, "֑Ⱥ|@!`":{}}, "ONEOFSTRING":"\u0012\t?A", "imported":{}, "type":"A�=*ى~~Ⱥ*ᾈാȺAᶊ?"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"INT32":-48, "UINT32":246, "INT64":"-21558176502", "SING64":"5030347", "UINT64":"28", "FIXED32":92, "DOUBLE":2.3547259926790202e-142, "STRING":"ಾ", "LIST":[{}, {}, {}, {}, {"x":" ᾚ DzA{˭҄\nA ^$?ᾦ,:<\"?_\u0014;|"}], "ONEOFSTRING":"𝟠", "LISTENUM":["Two", "One", "One"]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"INT32":22525032, "SINT32":897, "INT64":"-301128487533312", "SFIXED64":"-71", "FIXED64":"14", "DOUBLE":-2.983041182946181, "STRING":"-A^'", "MESSAGE":{"x":"#ऻ;́\r⋁"}, "LIST":[{}, {}, {}, {}, {}], "ONEOFSTRING":"", "imported":{}, "type":"₩\u0000^৴~౽ NjAৈ⃠𝖜ೄ"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"SINT32":1, "INT64":"-9223372036854775808", "SING64":"1", "FLOAT":-0.00013906474, "SFIXED64":"71414010", "STRING":"ף̂", "MESSAGE":{"x":" "}, "LIST":[{}], "ONEOFSTRING":"#¯∑Ⱥ�", "LISTENUM":["One", "One", "Two", "Two", "One", "One", "One", "Two"], "imported":{}, "type":"\u001b<ʰ+`𑱐@\u001b*Dž\u0000#₻\u0000"} |