Skip to content

Commit

Permalink
Fix panic that can occur when interpreting options in lenient mode (#331
Browse files Browse the repository at this point in the history
)

The panic fix is tiny. But this commit is bigger because other
changes/improvements were called for: most of the code changes are to
improve error handling when in lenient mode. After I fixed the panic,
the test case was failing in a different way, due to an issue with how
errors (even in lenient mode) were still being passed to an error
reporter and causing the stage to fail.

This issue was introduced in #279. That PR moved things around, pushing
the responsibility of calling `interp.reporter.HandleError` down, so
that interpreting a single option could potentially report multiple
errors (instead of failing fast and only reporting the first).

But when in lenient mode, we don't actually want to send those errors to
the reporter: sending the error to the reporter meant the error would
get memoized and then returned in subsequent handle calls, which could
cause the process to fail when it should be lenient and also can cause
it to report the wrong error in lenient mode. So now we demarcate the
parts of the process where errors are tolerated (i.e. where lenience
mode is actually activated) with a new field on the interpreter that is
examined when errors are reported.
  • Loading branch information
jhump authored Aug 20, 2024
1 parent 0de629a commit 0142a07
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 96 deletions.
11 changes: 6 additions & 5 deletions internal/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,24 @@ import (
"google.golang.org/protobuf/types/descriptorpb"

"github.com/bufbuild/protocompile/ast"
"github.com/bufbuild/protocompile/reporter"
)

type hasOptionNode interface {
OptionNode(part *descriptorpb.UninterpretedOption) ast.OptionDeclNode
FileNode() ast.FileDeclNode // needed in order to query for NodeInfo
}

func FindFirstOption(res hasOptionNode, handler *reporter.Handler, scope string, opts []*descriptorpb.UninterpretedOption, name string) (int, error) {
type errorHandler func(span ast.SourceSpan, format string, args ...interface{}) error

func FindFirstOption(res hasOptionNode, handler errorHandler, scope string, opts []*descriptorpb.UninterpretedOption, name string) (int, error) {
return findOption(res, handler, scope, opts, name, false, true)
}

func FindOption(res hasOptionNode, handler *reporter.Handler, scope string, opts []*descriptorpb.UninterpretedOption, name string) (int, error) {
func FindOption(res hasOptionNode, handler errorHandler, scope string, opts []*descriptorpb.UninterpretedOption, name string) (int, error) {
return findOption(res, handler, scope, opts, name, true, false)
}

func findOption(res hasOptionNode, handler *reporter.Handler, scope string, opts []*descriptorpb.UninterpretedOption, name string, exact, first bool) (int, error) {
func findOption(res hasOptionNode, handler errorHandler, scope string, opts []*descriptorpb.UninterpretedOption, name string, exact, first bool) (int, error) {
found := -1
for i, opt := range opts {
if exact && len(opt.Name) != 1 {
Expand All @@ -51,7 +52,7 @@ func findOption(res hasOptionNode, handler *reporter.Handler, scope string, opts
fn := res.FileNode()
node := optNode.GetName()
nodeInfo := fn.NodeInfo(node)
return -1, handler.HandleErrorf(nodeInfo, "%s: option %s cannot be defined more than once", scope, name)
return -1, handler(nodeInfo, "%s: option %s cannot be defined more than once", scope, name)
}
found = i
}
Expand Down
178 changes: 114 additions & 64 deletions options/options.go

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions options/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ func TestOptionsInUnlinkedFiles(t *testing.T) {
assert.Equal(t, "foo.bar", fd.GetOptions().GetGoPackage())
},
},
{
name: "file options, not custom",
contents: `option go_package = "foo.bar"; option must_link = "FOO";`,
uninterpreted: map[string]interface{}{
"test.proto:must_link": "FOO",
},
checkInterpreted: func(t *testing.T, fd *descriptorpb.FileDescriptorProto) {
assert.Equal(t, "foo.bar", fd.GetOptions().GetGoPackage())
},
},
{
name: "message options",
contents: `message Test { option (must.link) = 1.234; option deprecated = true; }`,
Expand Down Expand Up @@ -244,6 +254,25 @@ func TestOptionsInUnlinkedFiles(t *testing.T) {
}
}

func TestOptionsInUnlinkedFileInvalid(t *testing.T) {
t.Parallel()
h := reporter.NewHandler(nil)
ast, err := parser.Parse(
"test.proto",
strings.NewReader(
`syntax = "proto2";
package foo;
option malformed_non_existent = true;
option features.utf8_validation = NONE;`,
), h)
require.NoError(t, err, "failed to parse")
res, err := parser.ResultFromAST(ast, false, h)
require.NoError(t, err, "failed to produce descriptor proto")
_, err = options.InterpretUnlinkedOptions(res)
require.ErrorContains(t, err,
`test.proto:4:29: field "google.protobuf.FeatureSet.utf8_validation" was not introduced until edition 2023`)
}

func buildUninterpretedMapForFile(fd *descriptorpb.FileDescriptorProto, opts map[string]interface{}) {
buildUninterpretedMap(fd.GetName(), fd.GetOptions().GetUninterpretedOption(), opts)
for _, md := range fd.GetMessageType() {
Expand Down
2 changes: 1 addition & 1 deletion parser/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ func (r *result) addMessageBody(msgd *descriptorpb.DescriptorProto, body *ast.Me

func (r *result) isMessageSetWireFormat(scope string, md *descriptorpb.DescriptorProto, handler *reporter.Handler) (*descriptorpb.UninterpretedOption, error) {
uo := md.GetOptions().GetUninterpretedOption()
index, err := internal.FindOption(r, handler, scope, uo, "message_set_wire_format")
index, err := internal.FindOption(r, handler.HandleErrorf, scope, uo, "message_set_wire_format")
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions parser/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func validateNoFeatures(res *result, syntax protoreflect.Syntax, scope string, o
// Editions is allowed to use features
return nil
}
if index, err := internal.FindFirstOption(res, handler, scope, opts, "features"); err != nil {
if index, err := internal.FindFirstOption(res, handler.HandleErrorf, scope, opts, "features"); err != nil {
return err
} else if index >= 0 {
optNode := res.OptionNode(opts[index])
Expand All @@ -135,7 +135,7 @@ func validateMessage(res *result, syntax protoreflect.Syntax, name protoreflect.
}
}

if index, err := internal.FindOption(res, handler, scope, md.Options.GetUninterpretedOption(), "map_entry"); err != nil {
if index, err := internal.FindOption(res, handler.HandleErrorf, scope, md.Options.GetUninterpretedOption(), "map_entry"); err != nil {
return err
} else if index >= 0 {
optNode := res.OptionNode(md.Options.GetUninterpretedOption()[index])
Expand Down Expand Up @@ -331,7 +331,7 @@ func validateEnum(res *result, syntax protoreflect.Syntax, name protoreflect.Ful

allowAlias := false
var allowAliasOpt *descriptorpb.UninterpretedOption
if index, err := internal.FindOption(res, handler, scope, ed.Options.GetUninterpretedOption(), "allow_alias"); err != nil {
if index, err := internal.FindOption(res, handler.HandleErrorf, scope, ed.Options.GetUninterpretedOption(), "allow_alias"); err != nil {
return err
} else if index >= 0 {
allowAliasOpt = ed.Options.UninterpretedOption[index]
Expand Down Expand Up @@ -481,7 +481,7 @@ func validateField(res *result, syntax protoreflect.Syntax, name protoreflect.Fu
return err
}
}
if index, err := internal.FindOption(res, handler, scope, fld.Options.GetUninterpretedOption(), "packed"); err != nil {
if index, err := internal.FindOption(res, handler.HandleErrorf, scope, fld.Options.GetUninterpretedOption(), "packed"); err != nil {
return err
} else if index >= 0 {
optNode := res.OptionNode(fld.Options.GetUninterpretedOption()[index])
Expand All @@ -491,7 +491,7 @@ func validateField(res *result, syntax protoreflect.Syntax, name protoreflect.Fu
}
}
} else if syntax == protoreflect.Proto3 {
if index, err := internal.FindOption(res, handler, scope, fld.Options.GetUninterpretedOption(), "default"); err != nil {
if index, err := internal.FindOption(res, handler.HandleErrorf, scope, fld.Options.GetUninterpretedOption(), "default"); err != nil {
return err
} else if index >= 0 {
optNode := res.OptionNode(fld.Options.GetUninterpretedOption()[index])
Expand Down
17 changes: 11 additions & 6 deletions reporter/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,36 @@ type ErrorWithPos interface {

// Error creates a new ErrorWithPos from the given error and source position.
func Error(span ast.SourceSpan, err error) ErrorWithPos {
return errorWithSpan{SourceSpan: span, underlying: err}
var ewp ErrorWithPos
if errors.As(err, &ewp) {
// replace existing position with given one
return &errorWithSpan{SourceSpan: span, underlying: ewp.Unwrap()}
}
return &errorWithSpan{SourceSpan: span, underlying: err}
}

// Errorf creates a new ErrorWithPos whose underlying error is created using the
// given message format and arguments (via fmt.Errorf).
func Errorf(span ast.SourceSpan, format string, args ...interface{}) ErrorWithPos {
return errorWithSpan{SourceSpan: span, underlying: fmt.Errorf(format, args...)}
return Error(span, fmt.Errorf(format, args...))
}

type errorWithSpan struct {
ast.SourceSpan
underlying error
}

func (e errorWithSpan) Error() string {
func (e *errorWithSpan) Error() string {
sourcePos := e.GetPosition()
return fmt.Sprintf("%s: %v", sourcePos, e.underlying)
}

func (e errorWithSpan) GetPosition() ast.SourcePos {
func (e *errorWithSpan) GetPosition() ast.SourcePos {
return e.Start()
}

func (e errorWithSpan) Unwrap() error {
func (e *errorWithSpan) Unwrap() error {
return e.underlying
}

var _ ErrorWithPos = errorWithSpan{}
var _ ErrorWithPos = (*errorWithSpan)(nil)
17 changes: 2 additions & 15 deletions reporter/reporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,7 @@ func (h *Handler) HandleError(err error) error {
// call to HandleError or HandleErrorf), that same error is returned and the
// given error is not reported.
func (h *Handler) HandleErrorWithPos(span ast.SourceSpan, err error) error {
if ewp, ok := err.(ErrorWithPos); ok {
// replace existing position with given one
err = errorWithSpan{SourceSpan: span, underlying: ewp.Unwrap()}
} else {
err = errorWithSpan{SourceSpan: span, underlying: err}
}
return h.HandleError(err)
return h.HandleError(Error(span, err))
}

// HandleErrorf handles an error with the given source position, creating the
Expand Down Expand Up @@ -191,14 +185,7 @@ func (h *Handler) HandleWarning(err ErrorWithPos) {
// HandleWarningWithPos handles a warning with the given source position. This will
// delegate to the handler's configured reporter.
func (h *Handler) HandleWarningWithPos(span ast.SourceSpan, err error) {
ewp, ok := err.(ErrorWithPos)
if ok {
// replace existing position with given one
ewp = errorWithSpan{SourceSpan: span, underlying: ewp.Unwrap()}
} else {
ewp = errorWithSpan{SourceSpan: span, underlying: err}
}
h.HandleWarning(ewp)
h.HandleWarning(Error(span, err))
}

// HandleWarningf handles a warning with the given source position, creating the
Expand Down

0 comments on commit 0142a07

Please sign in to comment.