diff --git a/resolver/resolver.go b/resolver/resolver.go index 1f1654e..b8356e0 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -20,12 +20,13 @@ import ( "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/runtime/protoimpl" ) const ( - newExtensionIndex = "1159" - previousExtensionIndex = "51071" + newExtensionIndex = "1159" // protovalidate versions >= v0.2.0 + previousExtensionIndex = "51071" // protovalidate versions < v0.2.0 ) // DefaultResolver resolves protovalidate constraints options from descriptors. @@ -34,29 +35,31 @@ type DefaultResolver struct{} // ResolveMessageConstraints returns the MessageConstraints option set for the // MessageDescriptor. func (r DefaultResolver) ResolveMessageConstraints(desc protoreflect.MessageDescriptor) *validate.MessageConstraints { - constraints := resolveExt[protoreflect.MessageDescriptor, *validate.MessageConstraints](desc, validate.E_Message) - if constraints == nil { - constraints = resolveDeprecatedIndex[protoreflect.MessageDescriptor, *validate.MessageConstraints](desc, validate.E_Message) - } - return constraints + return resolveConstraints[validate.MessageConstraints](desc, validate.E_Message) } // ResolveOneofConstraints returns the OneofConstraints option set for the // OneofDescriptor. func (r DefaultResolver) ResolveOneofConstraints(desc protoreflect.OneofDescriptor) *validate.OneofConstraints { - constraints := resolveExt[protoreflect.OneofDescriptor, *validate.OneofConstraints](desc, validate.E_Oneof) - if constraints == nil { - constraints = resolveDeprecatedIndex[protoreflect.OneofDescriptor, *validate.OneofConstraints](desc, validate.E_Oneof) - } - return constraints + return resolveConstraints[validate.OneofConstraints](desc, validate.E_Oneof) } // ResolveFieldConstraints returns the FieldConstraints option set for the // FieldDescriptor. func (r DefaultResolver) ResolveFieldConstraints(desc protoreflect.FieldDescriptor) *validate.FieldConstraints { - constraints := resolveExt[protoreflect.FieldDescriptor, *validate.FieldConstraints](desc, validate.E_Field) + return resolveConstraints[validate.FieldConstraints](desc, validate.E_Field) +} + +func resolveConstraints[C any, CP interface { + *C + proto.Message +}]( + desc protoreflect.Descriptor, + extType *protoimpl.ExtensionInfo, +) (constraints CP) { + constraints = resolveExt[CP](desc.Options(), extType) if constraints == nil { - constraints = resolveDeprecatedIndex[protoreflect.FieldDescriptor, *validate.FieldConstraints](desc, validate.E_Field) + constraints = resolveDeprecatedIndex[CP](desc.Options(), extType) } return constraints } @@ -66,13 +69,14 @@ func (r DefaultResolver) ResolveFieldConstraints(desc protoreflect.FieldDescript // circumstances, particularly in dynamic or runtime contexts, the underlying // extension value's type may be a dynamicpb.Message. In this case, we fall back // through a proto.[Un]Marshal cycle to get it into the concrete type we expect. -func resolveExt[D protoreflect.Descriptor, C proto.Message]( - desc D, +func resolveExt[C proto.Message]( + options proto.Message, extType protoreflect.ExtensionType, ) (constraints C) { num := extType.TypeDescriptor().Number() var msg proto.Message - proto.RangeExtensions(desc.Options(), func(typ protoreflect.ExtensionType, i interface{}) bool { + + proto.RangeExtensions(options, func(typ protoreflect.ExtensionType, i interface{}) bool { if num != typ.TypeDescriptor().Number() { return true } @@ -93,16 +97,29 @@ func resolveExt[D protoreflect.Descriptor, C proto.Message]( } // resolveDeprecatedIndex is a fallback for the deprecated extension index. -func resolveDeprecatedIndex[D protoreflect.Descriptor, C proto.Message]( - desc D, +func resolveDeprecatedIndex[C proto.Message]( + options proto.Message, ext *protoimpl.ExtensionInfo, ) C { - return resolveExt[D, C](desc, &protoimpl.ExtensionInfo{ + extInfo := &protoimpl.ExtensionInfo{ ExtendedType: ext.ExtendedType, ExtensionType: ext.ExtensionType, Field: 51071, Name: ext.Name, Tag: strings.Replace(ext.Tag, newExtensionIndex, previousExtensionIndex, 1), Filename: ext.Filename, - }) + } + + // detect and handle if there are unknown options + if unknown := options.ProtoReflect().GetUnknown(); len(unknown) > 0 { + opts := options.ProtoReflect().Type().New() + resolver := &protoregistry.Types{} + if err := resolver.RegisterExtension(extInfo); err == nil { + if err = (&proto.UnmarshalOptions{Resolver: resolver}).Unmarshal(unknown, opts.Interface()); err == nil { + options = opts.Interface() + } + } + } + + return resolveExt[C](options, extInfo) }