Skip to content

Commit

Permalink
protoparse: fix extension resolution in custom options to match protoc (
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump authored Feb 4, 2022
1 parent d4949d2 commit 260eab9
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 43 deletions.
137 changes: 101 additions & 36 deletions desc/protoparse/linker.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (l *linker) resolveReferences() error {
prefix += "."
}
if fd.Options != nil {
if err := l.resolveOptions(r, fd, "file", fd.GetName(), proto.MessageName(fd.Options), fd.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "file", fd.GetName(), fd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand Down Expand Up @@ -342,14 +342,14 @@ func (l *linker) resolveReferences() error {
func (l *linker) resolveEnumTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, ed *dpb.EnumDescriptorProto, scopes []scope) error {
enumFqn := prefix + ed.GetName()
if ed.Options != nil {
if err := l.resolveOptions(r, fd, "enum", enumFqn, proto.MessageName(ed.Options), ed.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "enum", enumFqn, ed.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
for _, evd := range ed.Value {
if evd.Options != nil {
evFqn := enumFqn + "." + evd.GetName()
if err := l.resolveOptions(r, fd, "enum value", evFqn, proto.MessageName(evd.Options), evd.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "enum value", evFqn, evd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand All @@ -359,16 +359,21 @@ func (l *linker) resolveEnumTypes(r *parseResult, fd *dpb.FileDescriptorProto, p

func (l *linker) resolveMessageTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, md *dpb.DescriptorProto, scopes []scope) error {
fqn := prefix + md.GetName()
scope := messageScope(fqn, isProto3(fd), l, fd)
scopes = append(scopes, scope)
prefix = fqn + "."

// Strangely, when protoc resolves extension names, it uses the *enclosing* scope
// instead of the message's scope. So if the message contains an extension named "i",
// an option cannot refer to it as simply "i" but must qualify it (at a minimum "Msg.i").
// So we don't add this messages scope to our scopes slice until *after* we do options.
if md.Options != nil {
if err := l.resolveOptions(r, fd, "message", fqn, proto.MessageName(md.Options), md.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "message", fqn, md.Options.UninterpretedOption, scopes); err != nil {
return err
}
}

scope := messageScope(fqn, isProto3(fd), l, fd)
scopes = append(scopes, scope)
prefix = fqn + "."

for _, nmd := range md.NestedType {
if err := l.resolveMessageTypes(r, fd, prefix, nmd, scopes); err != nil {
return err
Expand All @@ -387,7 +392,7 @@ func (l *linker) resolveMessageTypes(r *parseResult, fd *dpb.FileDescriptorProto
for _, ood := range md.OneofDecl {
if ood.Options != nil {
ooName := fmt.Sprintf("%s.%s", fqn, ood.GetName())
if err := l.resolveOptions(r, fd, "oneof", ooName, proto.MessageName(ood.Options), ood.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "oneof", ooName, ood.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand All @@ -400,7 +405,7 @@ func (l *linker) resolveMessageTypes(r *parseResult, fd *dpb.FileDescriptorProto
for _, er := range md.ExtensionRange {
if er.Options != nil {
erName := fmt.Sprintf("%s:%d-%d", fqn, er.GetStart(), er.GetEnd()-1)
if err := l.resolveOptions(r, fd, "extension range", erName, proto.MessageName(er.Options), er.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "extension range", erName, er.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand Down Expand Up @@ -459,7 +464,7 @@ func (l *linker) resolveFieldTypes(r *parseResult, fd *dpb.FileDescriptorProto,
}

if fld.Options != nil {
if err := l.resolveOptions(r, fd, elemType, thisName, proto.MessageName(fld.Options), fld.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, elemType, thisName, fld.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand Down Expand Up @@ -501,7 +506,7 @@ func (l *linker) resolveFieldTypes(r *parseResult, fd *dpb.FileDescriptorProto,
func (l *linker) resolveServiceTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, sd *dpb.ServiceDescriptorProto, scopes []scope) error {
svcFqn := prefix + sd.GetName()
if sd.Options != nil {
if err := l.resolveOptions(r, fd, "service", svcFqn, proto.MessageName(sd.Options), sd.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "service", svcFqn, sd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand All @@ -512,7 +517,7 @@ func (l *linker) resolveServiceTypes(r *parseResult, fd *dpb.FileDescriptorProto

for _, mtd := range sd.Method {
if mtd.Options != nil {
if err := l.resolveOptions(r, fd, "method", svcFqn+"."+mtd.GetName(), proto.MessageName(mtd.Options), mtd.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "method", svcFqn+"."+mtd.GetName(), mtd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand Down Expand Up @@ -558,48 +563,108 @@ func (l *linker) resolveServiceTypes(r *parseResult, fd *dpb.FileDescriptorProto
return nil
}

func (l *linker) resolveOptions(r *parseResult, fd *dpb.FileDescriptorProto, elemType, elemName, optType string, opts []*dpb.UninterpretedOption, scopes []scope) error {
var scope string
if elemType != "file" {
scope = fmt.Sprintf("%s %s: ", elemType, elemName)
func (l *linker) resolveOptions(r *parseResult, fd *dpb.FileDescriptorProto, elemType, elemName string, opts []*dpb.UninterpretedOption, scopes []scope) error {
mc := &messageContext{
res: r,
elementName: elemName,
elementType: elemType,
}
opts:
for _, opt := range opts {
// resolve any extension names found in option names
for _, nm := range opt.Name {
if nm.GetIsExtension() {
node := r.getOptionNamePartNode(nm)
fqn, dsc, _ := l.resolve(fd, nm.GetNamePart(), false, scopes)
if dsc == nil {
if err := l.errs.handleErrorWithPos(node.Start(), "%sunknown extension %s", scope, nm.GetNamePart()); err != nil {
return err
}
continue opts
}
if dsc == sentinelMissingSymbol {
if err := l.errs.handleErrorWithPos(node.Start(), "%sunknown extension %s; resolved to %s which is not defined; consider using a leading dot", scope, nm.GetNamePart(), fqn); err != nil {
fqn, err := l.resolveExtensionName(nm.GetNamePart(), fd, scopes)
if err != nil {
node := r.getOptionNamePartNode(nm)
if err := l.errs.handleErrorWithPos(node.Start(), "%v%v", mc, err); err != nil {
return err
}
continue opts
}
if ext, ok := dsc.(*dpb.FieldDescriptorProto); !ok {
otherType := descriptorType(dsc)
if err := l.errs.handleErrorWithPos(node.Start(), "%sinvalid extension: %s is a %s, not an extension", scope, nm.GetNamePart(), otherType); err != nil {
return err
}
continue opts
} else if ext.GetExtendee() == "" {
if err := l.errs.handleErrorWithPos(node.Start(), "%sinvalid extension: %s is a field but not an extension", scope, nm.GetNamePart()); err != nil {
nm.NamePart = proto.String(fqn)
}
}
// also resolve any extension names found inside message literals in option values
mc.option = opt
optVal := r.getOptionNode(opt).GetValue()
if err := l.resolveOptionValue(r, mc, fd, optVal, scopes); err != nil {
return err
}
mc.option = nil
}
return nil
}

func (l *linker) resolveOptionValue(r *parseResult, mc *messageContext, fd *dpb.FileDescriptorProto, val ast.ValueNode, scopes []scope) error {
optVal := val.Value()
switch optVal := optVal.(type) {
case []ast.ValueNode:
origPath := mc.optAggPath
defer func() {
mc.optAggPath = origPath
}()
for i, v := range optVal {
mc.optAggPath = fmt.Sprintf("%s[%d]", origPath, i)
if err := l.resolveOptionValue(r, mc, fd, v, scopes); err != nil {
return err
}
}
case []*ast.MessageFieldNode:
origPath := mc.optAggPath
defer func() {
mc.optAggPath = origPath
}()
for _, fld := range optVal {
// check for extension name
if fld.Name.IsExtension() {
fqn, err := l.resolveExtensionName(string(fld.Name.Name.AsIdentifier()), fd, scopes)
if err != nil {
if err := l.errs.handleErrorWithPos(fld.Name.Name.Start(), "%v%v", mc, err); err != nil {
return err
}
continue opts
} else {
r.optionQualifiedNames[fld.Name.Name] = fqn
}
nm.NamePart = proto.String("." + fqn)
}

// recurse into value
mc.optAggPath = origPath
if origPath != "" {
mc.optAggPath += "."
}
if fld.Name.IsExtension() {
mc.optAggPath = fmt.Sprintf("%s[%s]", mc.optAggPath, string(fld.Name.Name.AsIdentifier()))
} else {
mc.optAggPath = fmt.Sprintf("%s%s", mc.optAggPath, string(fld.Name.Name.AsIdentifier()))
}

if err := l.resolveOptionValue(r, mc, fd, fld.Val, scopes); err != nil {
return err
}
}
}

return nil
}

func (l *linker) resolveExtensionName(name string, fd *dpb.FileDescriptorProto, scopes []scope) (string, error) {
fqn, dsc, _ := l.resolve(fd, name, false, scopes)
if dsc == nil {
return "", fmt.Errorf("unknown extension %s", name)
}
if dsc == sentinelMissingSymbol {
return "", fmt.Errorf("unknown extension %s; resolved to %s which is not defined; consider using a leading dot", name, fqn)
}
if ext, ok := dsc.(*dpb.FieldDescriptorProto); !ok {
otherType := descriptorType(dsc)
return "", fmt.Errorf("invalid extension: %s is a %s, not an extension", name, otherType)
} else if ext.GetExtendee() == "" {
return "", fmt.Errorf("invalid extension: %s is a field but not an extension", name)
}
return "." + fqn, nil
}

func (l *linker) resolve(fd *dpb.FileDescriptorProto, name string, onlyTypes bool, scopes []scope) (fqn string, element proto.Message, proto3 bool) {
if strings.HasPrefix(name, ".") {
// already fully-qualified
Expand Down
121 changes: 119 additions & 2 deletions desc/protoparse/linker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ func TestLinkerValidation(t *testing.T) {
"extend google.protobuf.MessageOptions { optional Foo foo = 10001; }\n" +
"message Baz { option (foo) = { [Bar]< name: \"abc\" > }; }\n",
},
"foo.proto:6:30: message Baz: option (foo): field Bar not found",
"foo.proto:6:33: message Baz: option (foo): invalid extension: Bar is a message, not an extension",
},
{
map[string]string{
Expand Down Expand Up @@ -559,13 +559,130 @@ func TestLinkerValidation(t *testing.T) {
},
{
map[string]string{
"a.proto": "syntax=\"proto3\";\nmessage m{\n" +
"a.proto": "syntax=\"proto3\";\n" +
"message m{\n" +
" string z = 1;\n" +
" oneof z{int64 b=2;}\n" +
"}",
},
`a.proto:3:3: duplicate symbol m.z: already defined as oneof`,
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga) = {\n" +
" [foo.bar.b.c.i]: 123\n" +
" [bar.b.c.i]: 234\n" +
" [b.c.i]: 345\n" +
" };\n" +
" option (msga).(foo.bar.b.c.f) = 1.23;\n" +
" option (msga).(bar.b.c.f) = 2.34;\n" +
" option (msga).(b.c.f) = 3.45;\n" +
"}",
},
"", // should succeed
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"message b { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message c {\n" +
" extend a { optional b b = 1; }\n" +
" extend b { repeated int32 i = 1; repeated float f = 2; }\n" +
" option (msga) = {\n" +
" [foo.bar.c.b] {\n" +
" [foo.bar.c.i]: 123\n" +
" [bar.c.i]: 234\n" +
" [c.i]: 345\n" +
" }\n" +
" };\n" +
" option (msga).(foo.bar.c.b).(foo.bar.c.f) = 1.23;\n" +
" option (msga).(foo.bar.c.b).(bar.c.f) = 2.34;\n" +
" option (msga).(foo.bar.c.b).(c.f) = 3.45;\n" +
"}",
},
"", // should succeed
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga) = {\n" +
" [c.i]: 456\n" +
" };\n" +
"}",
},
"test.proto:11:6: message foo.bar.b: option (foo.bar.msga): unknown extension c.i",
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga) = {\n" +
" [i]: 567\n" +
" };\n" +
"}",
},
"test.proto:11:6: message foo.bar.b: option (foo.bar.msga): unknown extension i",
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga).(c.f) = 4.56;\n" +
"}",
},
"test.proto:10:17: message foo.bar.b: unknown extension c.f",
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga).(f) = 5.67;\n" +
"}",
},
"test.proto:10:17: message foo.bar.b: unknown extension f",
},
{
map[string]string{
"a.proto": "syntax=\"proto3\";\nmessage m{\n" +
Expand Down
7 changes: 6 additions & 1 deletion desc/protoparse/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -1348,10 +1348,15 @@ func fieldValue(res *parseResult, mc *messageContext, fld fldDescriptorish, val
}
var ffld *desc.FieldDescriptor
if a.Name.IsExtension() {
n := string(a.Name.Name.AsIdentifier())
n := res.optionQualifiedNames[a.Name.Name]
if n == "" {
// this should not be possible!
n = string(a.Name.Name.AsIdentifier())
}
ffld = findExtension(mc.file, n, false, map[fileDescriptorish]struct{}{})
if ffld == nil {
// may need to qualify with package name
// (this should not be necessary!)
pkg := mc.file.GetPackage()
if pkg != "" {
ffld = findExtension(mc.file, pkg+"."+n, false, map[fileDescriptorish]struct{}{})
Expand Down
Loading

0 comments on commit 260eab9

Please sign in to comment.