Skip to content

Commit

Permalink
Fix autoboxing (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
seena-stripe authored Sep 8, 2021
1 parent 7027988 commit cc2f2d0
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 10 deletions.
24 changes: 23 additions & 1 deletion go/protomodule/protomodule_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,23 @@ func (msg *protoMessage) SetField(name string, val starlark.Value) error {
return err
}

// Convert starlark.List and starlark.Dict on assignment
// Autoconvert starlark.List, starlark.Dict, wrapperspb on assignment
if fieldDesc.IsList() {
if starlarkListVal, ok := val.(*starlark.List); ok {
// To support repeated StringValue support autoboxing
// if relevant conversion, mutate incoming list
if fieldDesc.Kind() == protoreflect.MessageKind {
for i := 0; i < starlarkListVal.Len(); i++ {
msg, err := maybeConvertToWrapper(fieldDesc, starlarkListVal.Index(i))
if err != nil {
return err
}
if msg != nil {
starlarkListVal.SetIndex(i, msg)
}
}
}

// Convert starlark.List to protoRepeated
list, err := newProtoRepeatedFromList(fieldDesc, starlarkListVal)
if err != nil {
Expand All @@ -240,6 +254,14 @@ func (msg *protoMessage) SetField(name string, val starlark.Value) error {

val = mapVal
}
} else if fieldDesc.Kind() == protoreflect.MessageKind {
msg, err := maybeConvertToWrapper(fieldDesc, val)
if err != nil {
return err
}
if msg != nil {
val = msg
}
}

// Allow using msg_field = None to unset a scalar message field
Expand Down
68 changes: 68 additions & 0 deletions go/protomodule/protomodule_message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"go.starlark.net/starlark"
"go.starlark.net/syntax"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/wrapperspb"

pb "github.com/stripe/skycfg/internal/testdata/test_proto"
)
Expand Down Expand Up @@ -106,10 +107,22 @@ func TestMessageV2(t *testing.T) {
f_nested_enum = proto.package("skycfg.test_proto").MessageV2.NestedEnum.NESTED_ENUM_B,
f_oneof_a = "string in oneof",
f_bytes = "also some string",
# Autoboxed wrappers
f_BoolValue = True,
f_StringValue = "something",
f_DoubleValue = 3110.4120,
f_Int32Value = 110,
f_Int64Value = 2148483647,
f_BytesValue = "foo/bar/baz",
f_Uint32Value = 4294967295,
f_Uint64Value = 8294967295,
r_StringValue = ["s1","s2","s3"],
)`, nil)
if err != nil {
t.Fatal(err)
}

gotMsg := mustProtoMessage(t, val)
wantMsg := &pb.MessageV2{
FInt32: proto.Int32(1010),
Expand Down Expand Up @@ -142,6 +155,19 @@ func TestMessageV2(t *testing.T) {
FNestedEnum: pb.MessageV2_NESTED_ENUM_B.Enum(),
FOneof: &pb.MessageV2_FOneofA{"string in oneof"},
FBytes: []byte("also some string"),
F_BoolValue: &wrapperspb.BoolValue{Value: true},
F_StringValue: &wrapperspb.StringValue{Value: "something"},
F_DoubleValue: &wrapperspb.DoubleValue{Value: 3110.4120},
F_Int32Value: &wrapperspb.Int32Value{Value: 110},
F_Int64Value: &wrapperspb.Int64Value{Value: 2148483647},
F_BytesValue: &wrapperspb.BytesValue{Value: []byte("foo/bar/baz")},
F_Uint32Value: &wrapperspb.UInt32Value{Value: 4294967295},
F_Uint64Value: &wrapperspb.UInt64Value{Value: 8294967295},
R_StringValue: []*wrapperspb.StringValue([]*wrapperspb.StringValue{
&wrapperspb.StringValue{Value: "s1"},
&wrapperspb.StringValue{Value: "s2"},
&wrapperspb.StringValue{Value: "s3"},
}),
}
checkProtoEqual(t, wantMsg, gotMsg)

Expand Down Expand Up @@ -175,6 +201,15 @@ func TestMessageV2(t *testing.T) {
"f_oneof_a": `"string in oneof"`,
"f_oneof_b": `""`,
"f_bytes": `"also some string"`,
"f_BoolValue": `<google.protobuf.BoolValue value:true>`,
"f_StringValue": `<google.protobuf.StringValue value:"something">`,
"f_DoubleValue": `<google.protobuf.DoubleValue value:3110.412>`,
"f_Int32Value": `<google.protobuf.Int32Value value:110>`,
"f_Int64Value": `<google.protobuf.Int64Value value:2148483647>`,
"f_BytesValue": `<google.protobuf.BytesValue value:"foo/bar/baz">`,
"f_Uint32Value": `<google.protobuf.UInt32Value value:4294967295>`,
"f_Uint64Value": `<google.protobuf.UInt64Value value:8294967295>`,
"r_StringValue": `[<google.protobuf.StringValue value:"s1">, <google.protobuf.StringValue value:"s2">, <google.protobuf.StringValue value:"s3">]`,
}
attrs := val.(starlark.HasAttrs)
for attrName, wantAttr := range wantAttrs {
Expand Down Expand Up @@ -224,6 +259,17 @@ func TestMessageV3(t *testing.T) {
f_nested_enum = proto.package("skycfg.test_proto").MessageV3.NestedEnum.NESTED_ENUM_B,
f_oneof_a = "string in oneof",
f_bytes = "also some string",
# Autoboxed wrappers
f_BoolValue = True,
f_StringValue = "something",
f_DoubleValue = 3110.4120,
f_Int32Value = 110,
f_Int64Value = 2148483647,
f_BytesValue = "foo/bar/baz",
f_Uint32Value = 4294967295,
f_Uint64Value = 8294967295,
r_StringValue = ["s1","s2","s3"],
)`, nil)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -260,6 +306,19 @@ func TestMessageV3(t *testing.T) {
FNestedEnum: pb.MessageV3_NESTED_ENUM_B,
FOneof: &pb.MessageV3_FOneofA{"string in oneof"},
FBytes: []byte("also some string"),
F_BoolValue: &wrapperspb.BoolValue{Value: true},
F_StringValue: &wrapperspb.StringValue{Value: "something"},
F_DoubleValue: &wrapperspb.DoubleValue{Value: 3110.4120},
F_Int32Value: &wrapperspb.Int32Value{Value: 110},
F_Int64Value: &wrapperspb.Int64Value{Value: 2148483647},
F_BytesValue: &wrapperspb.BytesValue{Value: []byte("foo/bar/baz")},
F_Uint32Value: &wrapperspb.UInt32Value{Value: 4294967295},
F_Uint64Value: &wrapperspb.UInt64Value{Value: 8294967295},
R_StringValue: []*wrapperspb.StringValue([]*wrapperspb.StringValue{
&wrapperspb.StringValue{Value: "s1"},
&wrapperspb.StringValue{Value: "s2"},
&wrapperspb.StringValue{Value: "s3"},
}),
}
checkProtoEqual(t, wantMsg, gotMsg)

Expand Down Expand Up @@ -293,6 +352,15 @@ func TestMessageV3(t *testing.T) {
"f_oneof_a": `"string in oneof"`,
"f_oneof_b": `""`,
"f_bytes": `"also some string"`,
"f_BoolValue": `<google.protobuf.BoolValue value:true>`,
"f_StringValue": `<google.protobuf.StringValue value:"something">`,
"f_DoubleValue": `<google.protobuf.DoubleValue value:3110.412>`,
"f_Int32Value": `<google.protobuf.Int32Value value:110>`,
"f_Int64Value": `<google.protobuf.Int64Value value:2148483647>`,
"f_BytesValue": `<google.protobuf.BytesValue value:"foo/bar/baz">`,
"f_Uint32Value": `<google.protobuf.UInt32Value value:4294967295>`,
"f_Uint64Value": `<google.protobuf.UInt64Value value:8294967295>`,
"r_StringValue": `[<google.protobuf.StringValue value:"s1">, <google.protobuf.StringValue value:"s2">, <google.protobuf.StringValue value:"s3">]`,
}
attrs := val.(starlark.HasAttrs)
for attrName, wantAttr := range wantAttrs {
Expand Down
8 changes: 0 additions & 8 deletions go/protomodule/type_conversions.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,6 @@ func scalarValueFromStarlark(fieldDesc protoreflect.FieldDescriptor, val starlar
return protoreflect.Value{}, fmt.Errorf("ValueError: value %v overflows type \"uint32\".", valInt)
}
case protoreflect.MessageKind:
msg, err := maybeConvertToWrapper(fieldDesc, val)
if err != nil {
return protoreflect.Value{}, err
}
if msg != nil {
return protoreflect.ValueOf(msg.toProtoMessage().ProtoReflect()), nil
}

if msg, ok := val.(*protoMessage); ok {
if msg.Type() == typeName(fieldDesc) {
return protoreflect.ValueOf(msg.toProtoMessage().ProtoReflect()), nil
Expand Down
15 changes: 14 additions & 1 deletion internal/testdata/test_proto/test_proto_v2.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ syntax = "proto2";
option go_package = "github.com/stripe/skycfg/internal/testdata/test_proto";
package skycfg.test_proto;

import "google/protobuf/wrappers.proto";

message MessageV2 {
optional int32 f_int32 = 1;
optional int64 f_int64 = 2;
Expand Down Expand Up @@ -60,7 +62,18 @@ message MessageV2 {

optional bytes f_bytes = 19;

// NEXT: 20
optional google.protobuf.BoolValue f_BoolValue = 20;
optional google.protobuf.StringValue f_StringValue = 21;
optional google.protobuf.DoubleValue f_DoubleValue = 22;
optional google.protobuf.Int32Value f_Int32Value = 23;
optional google.protobuf.Int64Value f_Int64Value = 24;
optional google.protobuf.BytesValue f_BytesValue = 25;
optional google.protobuf.UInt32Value f_Uint32Value = 26;
optional google.protobuf.UInt64Value f_Uint64Value = 27;

repeated google.protobuf.StringValue r_StringValue = 28;

// NEXT: 29
}

enum ToplevelEnumV2 {
Expand Down

0 comments on commit cc2f2d0

Please sign in to comment.