Skip to content

Commit

Permalink
fix: fix encode proto well known types in form and url query (#1559)
Browse files Browse the repository at this point in the history
* fix encode proto well known types
  • Loading branch information
longXboy authored Oct 17, 2021
1 parent 014778b commit 210e414
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 96 deletions.
41 changes: 13 additions & 28 deletions encoding/form/proto_encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@ import (
"reflect"
"strconv"
"strings"
"time"

"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)

// EncodeMap encode proto message to url query.
Expand Down Expand Up @@ -84,7 +80,7 @@ func encodeByField(u url.Values, path string, v protoreflect.Message) error {
return err
}
default:
value, err := encodeField(fd, v.Get(fd))
value, err := EncodeField(fd, v.Get(fd))
if err != nil {
return err
}
Expand All @@ -98,7 +94,7 @@ func encodeByField(u url.Values, path string, v protoreflect.Message) error {
func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List) ([]string, error) {
var values []string
for i := 0; i < list.Len(); i++ {
value, err := encodeField(fieldDescriptor, list.Get(i))
value, err := EncodeField(fieldDescriptor, list.Get(i))
if err != nil {
return nil, err
}
Expand All @@ -111,11 +107,11 @@ func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list prot
func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map) (map[string]string, error) {
m := make(map[string]string)
mp.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
key, err := encodeField(fieldDescriptor.MapValue(), k.Value())
key, err := EncodeField(fieldDescriptor.MapValue(), k.Value())
if err != nil {
return false
}
value, err := encodeField(fieldDescriptor.MapValue(), v)
value, err := EncodeField(fieldDescriptor.MapValue(), v)
if err != nil {
return false
}
Expand All @@ -126,7 +122,8 @@ func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflec
return m, nil
}

func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
// EncodeField encode proto message filed
func EncodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
switch fieldDescriptor.Kind() {
case protoreflect.BoolKind:
return strconv.FormatBool(value.Bool()), nil
Expand All @@ -147,29 +144,17 @@ func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflec
}
}

// marshalMessage marshals the fields in the given protoreflect.Message.
// encodeMessage marshals the fields in the given protoreflect.Message.
// If the typeURL is non-empty, then a synthetic "@type" field is injected
// containing the URL as the value.
func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) {
switch msgDescriptor.FullName() {
case "google.protobuf.Timestamp":
t, ok := value.Interface().(*timestamppb.Timestamp)
if !ok {
return "", nil
}
return t.AsTime().Format(time.RFC3339Nano), nil
case "google.protobuf.Duration":
d, ok := value.Interface().(*durationpb.Duration)
if !ok {
return "", nil
}
return d.AsDuration().String(), nil
case "google.protobuf.BytesValue":
b, ok := value.Interface().(*wrapperspb.BytesValue)
if !ok {
return "", nil
}
return base64.StdEncoding.EncodeToString(b.Value), nil
case timestampMessageFullname:
return marshalTimestamp(value.Message())
case durationMessageFullname:
return marshalDuration(value.Message())
case bytesMessageFullname:
return marshalBytes(value.Message())
case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value",
"google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue":
fd := msgDescriptor.Fields()
Expand Down
88 changes: 88 additions & 0 deletions encoding/form/well_known_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package form

import (
"encoding/base64"
"fmt"
"math"
"strings"
"time"

"google.golang.org/protobuf/reflect/protoreflect"
)

const (
// timestamp
timestampMessageFullname protoreflect.FullName = "google.protobuf.Timestamp"
maxTimestampSeconds = 253402300799
minTimestampSeconds = -6213559680013
timestampSecondsFieldNumber protoreflect.FieldNumber = 1
timestampNanosFieldNumber protoreflect.FieldNumber = 2

// duration
durationMessageFullname protoreflect.FullName = "google.protobuf.Duration"
secondsInNanos = 999999999
durationSecondsFieldNumber protoreflect.FieldNumber = 1
durationNanosFieldNumber protoreflect.FieldNumber = 2

// bytes
bytesMessageFullname protoreflect.FullName = "google.protobuf.BytesValue"
bytesValueFieldNumber protoreflect.FieldNumber = 1
)

func marshalTimestamp(m protoreflect.Message) (string, error) {
fds := m.Descriptor().Fields()
fdSeconds := fds.ByNumber(timestampSecondsFieldNumber)
fdNanos := fds.ByNumber(timestampNanosFieldNumber)

secsVal := m.Get(fdSeconds)
nanosVal := m.Get(fdNanos)
secs := secsVal.Int()
nanos := nanosVal.Int()
if secs < minTimestampSeconds || secs > maxTimestampSeconds {
return "", fmt.Errorf("%s: seconds out of range %v", timestampMessageFullname, secs)
}
if nanos < 0 || nanos > secondsInNanos {
return "", fmt.Errorf("%s: nanos out of range %v", timestampMessageFullname, nanos)
}
// Uses RFC 3339, where generated output will be Z-normalized and uses 0, 3,
// 6 or 9 fractional digits.
t := time.Unix(secs, nanos).UTC()
x := t.Format("2006-01-02T15:04:05.000000000")
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, ".000")
return x + "Z", nil
}

func marshalDuration(m protoreflect.Message) (string, error) {
fds := m.Descriptor().Fields()
fdSeconds := fds.ByNumber(durationSecondsFieldNumber)
fdNanos := fds.ByNumber(durationNanosFieldNumber)

secsVal := m.Get(fdSeconds)
nanosVal := m.Get(fdNanos)
secs := secsVal.Int()
nanos := nanosVal.Int()
d := time.Duration(secs) * time.Second
overflow := d/time.Second != time.Duration(secs)
d += time.Duration(nanos) * time.Nanosecond
overflow = overflow || (secs < 0 && nanos < 0 && d > 0)
overflow = overflow || (secs > 0 && nanos > 0 && d < 0)
if overflow {
switch {
case secs < 0:
return time.Duration(math.MinInt64).String(), nil
case secs > 0:
return time.Duration(math.MaxInt64).String(), nil
}
}
return d.String(), nil
}

func marshalBytes(m protoreflect.Message) (string, error) {
fds := m.Descriptor().Fields()
fdBytes := fds.ByNumber(bytesValueFieldNumber)
bytesVal := m.Get(fdBytes)
val := bytesVal.Bytes()
return base64.StdEncoding.EncodeToString(val), nil
}
69 changes: 1 addition & 68 deletions transport/http/binding/encode.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
package binding

import (
"encoding/base64"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"

"github.com/go-kratos/kratos/v2/encoding/form"

"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)

// EncodeURL encode proto message to url path.
Expand Down Expand Up @@ -74,65 +67,5 @@ func getValueByField(v protoreflect.Message, fieldPath []string) (string, error)
}
v = v.Get(fd).Message()
}
return encodeField(fd, v.Get(fd))
}

func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
switch fieldDescriptor.Kind() {
case protoreflect.BoolKind:
return strconv.FormatBool(value.Bool()), nil
case protoreflect.EnumKind:
if fieldDescriptor.Enum().FullName() == "google.protobuf.NullValue" {
return "null", nil
}
desc := fieldDescriptor.Enum().Values().ByNumber(value.Enum())
return string(desc.Name()), nil
case protoreflect.StringKind:
return value.String(), nil
case protoreflect.BytesKind:
return base64.URLEncoding.EncodeToString(value.Bytes()), nil
case protoreflect.MessageKind, protoreflect.GroupKind:
return encodeMessage(fieldDescriptor.Message(), value)
default:
return fmt.Sprintf("%v", value.Interface()), nil
}
}

// encodeMessage marshals the fields in the given protoreflect.Message.
// If the typeURL is non-empty, then a synthetic "@type" field is injected
// containing the URL as the value.
func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) {
switch msgDescriptor.FullName() {
case "google.protobuf.Timestamp":
t, ok := value.Interface().(*timestamppb.Timestamp)
if !ok {
return "", nil
}
return t.AsTime().Format(time.RFC3339Nano), nil
case "google.protobuf.Duration":
d, ok := value.Interface().(*durationpb.Duration)
if !ok {
return "", nil
}
return d.AsDuration().String(), nil
case "google.protobuf.BytesValue":
b, ok := value.Interface().(*wrapperspb.BytesValue)
if !ok {
return "", nil
}
return base64.StdEncoding.EncodeToString(b.Value), nil
case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value",
"google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue":
fd := msgDescriptor.Fields()
v := value.Message().Get(fd.ByName(protoreflect.Name("value"))).Message()
return fmt.Sprintf("%v", v.Interface()), nil
case "google.protobuf.FieldMask":
m, ok := value.Interface().(*field_mask.FieldMask)
if !ok {
return "", nil
}
return strings.Join(m.Paths, ","), nil
default:
return "", fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
}
return form.EncodeField(fd, v.Get(fd))
}

0 comments on commit 210e414

Please sign in to comment.