Skip to content

Commit

Permalink
add go_struct_tags as a field
Browse files Browse the repository at this point in the history
  • Loading branch information
borosr committed Jan 9, 2023
1 parent 575603a commit 1fcded3
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 20 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ go_type = 1001
go_import = 1002
go_import_alias = 1003
go_zero_override = 1004
go_struct_tags = 1005
Because `protoc` can't process the extended options, so we can't find the by name, just by place.
69 changes: 50 additions & 19 deletions cmd/protoc-gen-go/internal_gengo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"go/token"
"math"
"os"
"regexp"
"strconv"
"strings"
"unicode"
Expand Down Expand Up @@ -40,6 +41,7 @@ type overrideParams struct {
goImport string
goImportAlias string
goZeroOverride string
goStructTags string
}

// overrideFields stores all the found messages which are created to override types
Expand All @@ -52,6 +54,8 @@ var SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPT
var GenerateProtobufSpecific = true
var TypeOverride = false

var tagRegex = regexp.MustCompile(`([^=;]+)=([^;]+);?`)

var lookupExtensionNames = []string{"go_type", "go_import", "go_import_alias"}

// Standard library dependencies.
Expand Down Expand Up @@ -202,6 +206,9 @@ func processExtensions(field *protogen.Field) (overrideParams, bool) {
log.Log("processExtensions:: zero override found!")
override.goZeroOverride = ex.Value().String()
ok = true
case impl.FieldOptionGoStructTags:
override.goStructTags = ex.Value().String()
ok = true
}
}
return override, ok
Expand All @@ -227,6 +234,9 @@ func processUninterpretedOptions(field *protogen.Field) (overrideParams, bool) {
case impl.FieldOptionGoZeroOverride:
override.goZeroOverride = string(o.GetStringValue())
ok = true
case impl.FieldOptionGoStructTags:
override.goStructTags = string(o.GetStringValue())
ok = true
}
}
}
Expand Down Expand Up @@ -531,13 +541,14 @@ func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, fie
if pointer {
goType = "*" + goType
}
if overrideParam, ok := goTypeOverride(m.GoIdent.GoName, field.GoName); ok {
if overrideParam, ok := getOverrideField(m.GoIdent.GoName, field.GoName); ok {
goType = overrideParam.goType
}
tags := structTags{
{"protobuf", fieldProtobufTagValue(field)},
{"json", fieldJSONTagValue(field)},
}
tags = append(tags, getAdditionalTags(m, field)...)

if field.Desc.IsMap() {
key := field.Message.Fields[0]
val := field.Message.Fields[1]
Expand All @@ -563,6 +574,41 @@ func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, fie
sf.append(field.GoName)
}

func getOverrideField(messageName, fieldName string) (overrideParams, bool) {
if !TypeOverride {
return overrideParams{}, false
}
overrideMsg, ok := overrideFields[messageName]
if !ok {
return overrideParams{}, false
}
overrideField, ok := overrideMsg[fieldName]
return overrideField, ok
}

func getAdditionalTags(m *messageInfo, field *protogen.Field) [][2]string {
o, ok := getOverrideField(m.GoIdent.GoName, field.GoName)
if !ok || o.goStructTags == "" {
return [][2]string{
{"json", fieldJSONTagValue(field)},
}
}
return buildTags(o.goStructTags)
}

func buildTags(wrappedTags string) [][2]string {
var tags [][2]string
for _, token := range strings.Split(wrappedTags, ";") {
for _, submatch := range tagRegex.FindAllStringSubmatch(token, -1) {
if len(submatch) != 3 {
continue
}
tags = append(tags, [2]string{submatch[1], submatch[2]})
}
}
return tags
}

// genMessageDefaultDecls generates consts and vars holding the default
// values of fields.
func genMessageDefaultDecls(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
Expand All @@ -573,7 +619,7 @@ func genMessageDefaultDecls(g *protogen.GeneratedFile, f *fileInfo, m *messageIn
}
name := "Default_" + m.GoIdent.GoName + "_" + field.GoName
goType, _ := fieldGoType(g, f, field)
if overrideParam, ok := goTypeOverride(m.GoIdent.GoName, field.GoName); ok {
if overrideParam, ok := getOverrideField(m.GoIdent.GoName, field.GoName); ok {
goType = overrideParam.goType
}
defVal := field.Desc.Default()
Expand Down Expand Up @@ -698,7 +744,7 @@ func genMessageGetterMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageI

// Getter for message field.
goType, pointer := fieldGoType(g, f, field)
overrideParam, overwritten := goTypeOverride(m.GoIdent.GoName, field.GoName)
overrideParam, overwritten := getOverrideField(m.GoIdent.GoName, field.GoName)
if overwritten {
goType = overrideParam.goType
}
Expand Down Expand Up @@ -1068,18 +1114,3 @@ func (c trailingComment) String() string {
}
return s
}

func goTypeOverride(msgName string, fieldName string) (overrideParams, bool) {
if TypeOverride {
// TODO check the case when goType is a map
if oMsg, okMsg := overrideFields[msgName]; okMsg {
if o, okField := oMsg[fieldName]; okField {
return o, true
}
}
// if strings.Contains(goType, "RepeatedString") {
// return strings.ReplaceAll(goType, "RepeatedString", "[]string")
// }
}
return overrideParams{}, false
}
62 changes: 62 additions & 0 deletions cmd/protoc-gen-go/internal_gengo/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"testing"

"github.com/infiniteloopcloud/protoc-gen-go-types/compiler/protogen"
"github.com/infiniteloopcloud/protoc-gen-go-types/parser"
)

Expand All @@ -30,3 +31,64 @@ func TestGenerateFile(t *testing.T) {
}
}
}

func TestBuildTags(t *testing.T) {
tags := buildTags("json=something,omitempty;validate=date")
if len(tags) != 2 {
t.Fatal("invalid tag count, must be 2")
}
if tags[0][0] != "json" {
t.Errorf("Invalid tag name, should be 'json', instead of %q", tags[0][0])
}
if tags[0][1] != "something,omitempty" {
t.Errorf("Invalid tag name, should be 'something,omitempty', instead of %q", tags[0][1])
}
if tags[1][0] != "validate" {
t.Errorf("Invalid tag name, should be 'validate', instead of %q", tags[1][0])
}
if tags[1][1] != "date" {
t.Errorf("Invalid tag name, should be 'date', instead of %q", tags[1][1])
}
}

func TestGetAdditionalTags(t *testing.T) {
TypeOverride = true
overrideFields = map[string]map[string]overrideParams{
"TestStruct": {
"TestField": overrideParams{
goStructTags: `json=id_dont_know,omitempty;boil=donno;validate=true`,
},
},
}
tags := getAdditionalTags(&messageInfo{
Message: &protogen.Message{
GoIdent: protogen.GoIdent{
GoName: "TestStruct",
},
},
}, &protogen.Field{
GoName: "TestField",
})

if len(tags) != 3 {
t.Fatal("Tags length must be 3")
}
if tags[0][0] != "json" {
t.Errorf("First tag should be `json`, instead of %s", tags[0][0])
}
if tags[0][1] != "id_dont_know,omitempty" {
t.Errorf("First tag should be `id_dont_know,omitempty`, instead of %s", tags[0][1])
}
if tags[1][0] != "boil" {
t.Errorf("First tag should be `boil`, instead of %s", tags[1][0])
}
if tags[1][1] != "donno" {
t.Errorf("First tag should be `donno`, instead of %s", tags[1][1])
}
if tags[2][0] != "validate" {
t.Errorf("First tag should be `validate`, instead of %s", tags[2][0])
}
if tags[2][1] != "true" {
t.Errorf("First tag should be `true`, instead of %s", tags[2][1])
}
}
1 change: 1 addition & 0 deletions cmd/protoc-gen-go/internal_gengo/test_data/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ extend google.protobuf.FieldOptions {
optional string go_import = 1002;
optional string go_import_alias = 1003;
optional string go_zero_override = 1004;
optional string go_struct_tags = 1005;
}
2 changes: 1 addition & 1 deletion cmd/protoc-gen-go/internal_gengo/test_data/test.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ message Test {
map<uint64, RepeatedString> map_field = 2 [(go_type) = "map[uint64][]string", (go_import) = ""];
string test = 3;
string optStr = 4 [(go_type) = "null.String", (go_import) = "github.com/volatiletech/null/v9", (go_import_alias) = "null"];
int32 optInt = 5 [(go_type) = "null.Int32"];
int32 optInt = 5 [(go_type) = "null.Int32", (go_struct_tags) = "json=-;boil=hello"];
int32 optBigInt = 6 [(go_type) = "null.Int64", (go_import) = "github.com/volatiletech/null/v9", (go_import_alias) = "null"];
string something = 7 [(go_type) = "Something", (go_import) = "", (go_zero_override) = "\"\""];
}
4 changes: 4 additions & 0 deletions internal/impl/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ const (
FieldOptionGoImport = "go_import"
FieldOptionGoImportAlias = "go_import_alias"
FieldOptionGoZeroOverride = "go_zero_override"
FieldOptionGoStructTags = "go_struct_tags"

FieldOptionGoTypeNum = 1001
FieldOptionGoImportNum = 1002
FieldOptionGoImportAliasNum = 1003
FieldOptionGoZeroOverrideNum = 1004
FieldOptionGoStructTagsNum = 1005
)

var errDecode = errors.New("cannot parse invalid wire-format data")
Expand Down Expand Up @@ -285,6 +287,8 @@ func (mi *MessageInfo) fallbackCreateExtension(num protowire.Number) (protorefle
name = FieldOptionGoImportAlias
case FieldOptionGoZeroOverrideNum:
name = FieldOptionGoZeroOverride
case FieldOptionGoStructTagsNum:
name = FieldOptionGoStructTags
default:
return nil, errors.New("invalid name")
}
Expand Down

0 comments on commit 1fcded3

Please sign in to comment.