Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protoparse: fix extension resolution in custom options #484

Merged
merged 1 commit into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 101 additions & 36 deletions desc/protoparse/linker.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (l *linker) resolveReferences() error {
prefix += "."
}
if fd.Options != nil {
if err := l.resolveOptions(r, fd, "file", fd.GetName(), proto.MessageName(fd.Options), fd.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "file", fd.GetName(), fd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand Down Expand Up @@ -342,14 +342,14 @@ func (l *linker) resolveReferences() error {
func (l *linker) resolveEnumTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, ed *dpb.EnumDescriptorProto, scopes []scope) error {
enumFqn := prefix + ed.GetName()
if ed.Options != nil {
if err := l.resolveOptions(r, fd, "enum", enumFqn, proto.MessageName(ed.Options), ed.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "enum", enumFqn, ed.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
for _, evd := range ed.Value {
if evd.Options != nil {
evFqn := enumFqn + "." + evd.GetName()
if err := l.resolveOptions(r, fd, "enum value", evFqn, proto.MessageName(evd.Options), evd.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "enum value", evFqn, evd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand All @@ -359,16 +359,21 @@ func (l *linker) resolveEnumTypes(r *parseResult, fd *dpb.FileDescriptorProto, p

func (l *linker) resolveMessageTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, md *dpb.DescriptorProto, scopes []scope) error {
fqn := prefix + md.GetName()
scope := messageScope(fqn, isProto3(fd), l, fd)
scopes = append(scopes, scope)
prefix = fqn + "."

// Strangely, when protoc resolves extension names, it uses the *enclosing* scope
// instead of the message's scope. So if the message contains an extension named "i",
// an option cannot refer to it as simply "i" but must qualify it (at a minimum "Msg.i").
// So we don't add this messages scope to our scopes slice until *after* we do options.
if md.Options != nil {
if err := l.resolveOptions(r, fd, "message", fqn, proto.MessageName(md.Options), md.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "message", fqn, md.Options.UninterpretedOption, scopes); err != nil {
return err
}
}

scope := messageScope(fqn, isProto3(fd), l, fd)
scopes = append(scopes, scope)
prefix = fqn + "."

for _, nmd := range md.NestedType {
if err := l.resolveMessageTypes(r, fd, prefix, nmd, scopes); err != nil {
return err
Expand All @@ -387,7 +392,7 @@ func (l *linker) resolveMessageTypes(r *parseResult, fd *dpb.FileDescriptorProto
for _, ood := range md.OneofDecl {
if ood.Options != nil {
ooName := fmt.Sprintf("%s.%s", fqn, ood.GetName())
if err := l.resolveOptions(r, fd, "oneof", ooName, proto.MessageName(ood.Options), ood.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "oneof", ooName, ood.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand All @@ -400,7 +405,7 @@ func (l *linker) resolveMessageTypes(r *parseResult, fd *dpb.FileDescriptorProto
for _, er := range md.ExtensionRange {
if er.Options != nil {
erName := fmt.Sprintf("%s:%d-%d", fqn, er.GetStart(), er.GetEnd()-1)
if err := l.resolveOptions(r, fd, "extension range", erName, proto.MessageName(er.Options), er.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "extension range", erName, er.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand Down Expand Up @@ -459,7 +464,7 @@ func (l *linker) resolveFieldTypes(r *parseResult, fd *dpb.FileDescriptorProto,
}

if fld.Options != nil {
if err := l.resolveOptions(r, fd, elemType, thisName, proto.MessageName(fld.Options), fld.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, elemType, thisName, fld.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand Down Expand Up @@ -501,7 +506,7 @@ func (l *linker) resolveFieldTypes(r *parseResult, fd *dpb.FileDescriptorProto,
func (l *linker) resolveServiceTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, sd *dpb.ServiceDescriptorProto, scopes []scope) error {
svcFqn := prefix + sd.GetName()
if sd.Options != nil {
if err := l.resolveOptions(r, fd, "service", svcFqn, proto.MessageName(sd.Options), sd.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "service", svcFqn, sd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand All @@ -512,7 +517,7 @@ func (l *linker) resolveServiceTypes(r *parseResult, fd *dpb.FileDescriptorProto

for _, mtd := range sd.Method {
if mtd.Options != nil {
if err := l.resolveOptions(r, fd, "method", svcFqn+"."+mtd.GetName(), proto.MessageName(mtd.Options), mtd.Options.UninterpretedOption, scopes); err != nil {
if err := l.resolveOptions(r, fd, "method", svcFqn+"."+mtd.GetName(), mtd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
Expand Down Expand Up @@ -558,48 +563,108 @@ func (l *linker) resolveServiceTypes(r *parseResult, fd *dpb.FileDescriptorProto
return nil
}

func (l *linker) resolveOptions(r *parseResult, fd *dpb.FileDescriptorProto, elemType, elemName, optType string, opts []*dpb.UninterpretedOption, scopes []scope) error {
var scope string
if elemType != "file" {
scope = fmt.Sprintf("%s %s: ", elemType, elemName)
func (l *linker) resolveOptions(r *parseResult, fd *dpb.FileDescriptorProto, elemType, elemName string, opts []*dpb.UninterpretedOption, scopes []scope) error {
mc := &messageContext{
res: r,
elementName: elemName,
elementType: elemType,
}
opts:
for _, opt := range opts {
// resolve any extension names found in option names
for _, nm := range opt.Name {
if nm.GetIsExtension() {
node := r.getOptionNamePartNode(nm)
fqn, dsc, _ := l.resolve(fd, nm.GetNamePart(), false, scopes)
if dsc == nil {
if err := l.errs.handleErrorWithPos(node.Start(), "%sunknown extension %s", scope, nm.GetNamePart()); err != nil {
return err
}
continue opts
}
if dsc == sentinelMissingSymbol {
if err := l.errs.handleErrorWithPos(node.Start(), "%sunknown extension %s; resolved to %s which is not defined; consider using a leading dot", scope, nm.GetNamePart(), fqn); err != nil {
fqn, err := l.resolveExtensionName(nm.GetNamePart(), fd, scopes)
if err != nil {
node := r.getOptionNamePartNode(nm)
if err := l.errs.handleErrorWithPos(node.Start(), "%v%v", mc, err); err != nil {
return err
}
continue opts
}
if ext, ok := dsc.(*dpb.FieldDescriptorProto); !ok {
otherType := descriptorType(dsc)
if err := l.errs.handleErrorWithPos(node.Start(), "%sinvalid extension: %s is a %s, not an extension", scope, nm.GetNamePart(), otherType); err != nil {
return err
}
continue opts
} else if ext.GetExtendee() == "" {
if err := l.errs.handleErrorWithPos(node.Start(), "%sinvalid extension: %s is a field but not an extension", scope, nm.GetNamePart()); err != nil {
nm.NamePart = proto.String(fqn)
}
}
// also resolve any extension names found inside message literals in option values
mc.option = opt
optVal := r.getOptionNode(opt).GetValue()
if err := l.resolveOptionValue(r, mc, fd, optVal, scopes); err != nil {
return err
}
mc.option = nil
}
return nil
}

func (l *linker) resolveOptionValue(r *parseResult, mc *messageContext, fd *dpb.FileDescriptorProto, val ast.ValueNode, scopes []scope) error {
optVal := val.Value()
switch optVal := optVal.(type) {
case []ast.ValueNode:
origPath := mc.optAggPath
defer func() {
mc.optAggPath = origPath
}()
for i, v := range optVal {
mc.optAggPath = fmt.Sprintf("%s[%d]", origPath, i)
if err := l.resolveOptionValue(r, mc, fd, v, scopes); err != nil {
return err
}
}
case []*ast.MessageFieldNode:
origPath := mc.optAggPath
defer func() {
mc.optAggPath = origPath
}()
for _, fld := range optVal {
// check for extension name
if fld.Name.IsExtension() {
fqn, err := l.resolveExtensionName(string(fld.Name.Name.AsIdentifier()), fd, scopes)
if err != nil {
if err := l.errs.handleErrorWithPos(fld.Name.Name.Start(), "%v%v", mc, err); err != nil {
return err
}
continue opts
} else {
r.optionQualifiedNames[fld.Name.Name] = fqn
}
nm.NamePart = proto.String("." + fqn)
}

// recurse into value
mc.optAggPath = origPath
if origPath != "" {
mc.optAggPath += "."
}
if fld.Name.IsExtension() {
mc.optAggPath = fmt.Sprintf("%s[%s]", mc.optAggPath, string(fld.Name.Name.AsIdentifier()))
} else {
mc.optAggPath = fmt.Sprintf("%s%s", mc.optAggPath, string(fld.Name.Name.AsIdentifier()))
}

if err := l.resolveOptionValue(r, mc, fd, fld.Val, scopes); err != nil {
return err
}
}
}

return nil
}

func (l *linker) resolveExtensionName(name string, fd *dpb.FileDescriptorProto, scopes []scope) (string, error) {
fqn, dsc, _ := l.resolve(fd, name, false, scopes)
if dsc == nil {
return "", fmt.Errorf("unknown extension %s", name)
}
if dsc == sentinelMissingSymbol {
return "", fmt.Errorf("unknown extension %s; resolved to %s which is not defined; consider using a leading dot", name, fqn)
}
if ext, ok := dsc.(*dpb.FieldDescriptorProto); !ok {
otherType := descriptorType(dsc)
return "", fmt.Errorf("invalid extension: %s is a %s, not an extension", name, otherType)
} else if ext.GetExtendee() == "" {
return "", fmt.Errorf("invalid extension: %s is a field but not an extension", name)
}
return "." + fqn, nil
}

func (l *linker) resolve(fd *dpb.FileDescriptorProto, name string, onlyTypes bool, scopes []scope) (fqn string, element proto.Message, proto3 bool) {
if strings.HasPrefix(name, ".") {
// already fully-qualified
Expand Down
121 changes: 119 additions & 2 deletions desc/protoparse/linker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ func TestLinkerValidation(t *testing.T) {
"extend google.protobuf.MessageOptions { optional Foo foo = 10001; }\n" +
"message Baz { option (foo) = { [Bar]< name: \"abc\" > }; }\n",
},
"foo.proto:6:30: message Baz: option (foo): field Bar not found",
"foo.proto:6:33: message Baz: option (foo): invalid extension: Bar is a message, not an extension",
},
{
map[string]string{
Expand Down Expand Up @@ -559,13 +559,130 @@ func TestLinkerValidation(t *testing.T) {
},
{
map[string]string{
"a.proto": "syntax=\"proto3\";\nmessage m{\n" +
"a.proto": "syntax=\"proto3\";\n" +
"message m{\n" +
" string z = 1;\n" +
" oneof z{int64 b=2;}\n" +
"}",
},
`a.proto:3:3: duplicate symbol m.z: already defined as oneof`,
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga) = {\n" +
" [foo.bar.b.c.i]: 123\n" +
" [bar.b.c.i]: 234\n" +
" [b.c.i]: 345\n" +
" };\n" +
" option (msga).(foo.bar.b.c.f) = 1.23;\n" +
" option (msga).(bar.b.c.f) = 2.34;\n" +
" option (msga).(b.c.f) = 3.45;\n" +
"}",
},
"", // should succeed
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"message b { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message c {\n" +
" extend a { optional b b = 1; }\n" +
" extend b { repeated int32 i = 1; repeated float f = 2; }\n" +
" option (msga) = {\n" +
" [foo.bar.c.b] {\n" +
" [foo.bar.c.i]: 123\n" +
" [bar.c.i]: 234\n" +
" [c.i]: 345\n" +
" }\n" +
" };\n" +
" option (msga).(foo.bar.c.b).(foo.bar.c.f) = 1.23;\n" +
" option (msga).(foo.bar.c.b).(bar.c.f) = 2.34;\n" +
" option (msga).(foo.bar.c.b).(c.f) = 3.45;\n" +
"}",
},
"", // should succeed
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga) = {\n" +
" [c.i]: 456\n" +
" };\n" +
"}",
},
"test.proto:11:6: message foo.bar.b: option (foo.bar.msga): unknown extension c.i",
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga) = {\n" +
" [i]: 567\n" +
" };\n" +
"}",
},
"test.proto:11:6: message foo.bar.b: option (foo.bar.msga): unknown extension i",
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga).(c.f) = 4.56;\n" +
"}",
},
"test.proto:10:17: message foo.bar.b: unknown extension c.f",
},
{
map[string]string{
"test.proto": "syntax=\"proto2\";\n" +
"package foo.bar;\n" +
"import \"google/protobuf/descriptor.proto\";\n" +
"message a { extensions 1 to 100; }\n" +
"extend google.protobuf.MessageOptions { optional a msga = 10000; }\n" +
"message b {\n" +
" message c {\n" +
" extend a { repeated int32 i = 1; repeated float f = 2; }\n" +
" }\n" +
" option (msga).(f) = 5.67;\n" +
"}",
},
"test.proto:10:17: message foo.bar.b: unknown extension f",
},
{
map[string]string{
"a.proto": "syntax=\"proto3\";\nmessage m{\n" +
Expand Down
7 changes: 6 additions & 1 deletion desc/protoparse/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -1348,10 +1348,15 @@ func fieldValue(res *parseResult, mc *messageContext, fld fldDescriptorish, val
}
var ffld *desc.FieldDescriptor
if a.Name.IsExtension() {
n := string(a.Name.Name.AsIdentifier())
n := res.optionQualifiedNames[a.Name.Name]
if n == "" {
// this should not be possible!
n = string(a.Name.Name.AsIdentifier())
}
ffld = findExtension(mc.file, n, false, map[fileDescriptorish]struct{}{})
if ffld == nil {
// may need to qualify with package name
// (this should not be necessary!)
pkg := mc.file.GetPackage()
if pkg != "" {
ffld = findExtension(mc.file, pkg+"."+n, false, map[fileDescriptorish]struct{}{})
Expand Down
Loading