Skip to content

Commit

Permalink
Merge pull request #10 from infiniteloopcloud/9-go-types-override
Browse files Browse the repository at this point in the history
Go types override
  • Loading branch information
PumpkinSeed authored Nov 1, 2022
2 parents 42d4020 + 07565df commit 502726f
Show file tree
Hide file tree
Showing 11 changed files with 362 additions and 36 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,8 @@ message TimeTime {}

#### Important

Currently the generator not importing the `time` package automatically. Temporary solution can be: `goimports -w *.pb.go`.
Currently the overwritten FieldOptions (go_type, go_import, go_import_alias) must be paired with these numbers:
go_type = 1001
go_import = 1002
go_import_alias = 1003
Because `protoc` can't process the extended options, so we can't find the by name, just by place.
163 changes: 147 additions & 16 deletions cmd/protoc-gen-go/internal_gengo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
"github.com/infiniteloopcloud/protoc-gen-go-types/compiler/protogen"
"github.com/infiniteloopcloud/protoc-gen-go-types/internal/encoding/tag"
"github.com/infiniteloopcloud/protoc-gen-go-types/internal/genid"
"github.com/infiniteloopcloud/protoc-gen-go-types/internal/impl"
"github.com/infiniteloopcloud/protoc-gen-go-types/internal/version"
"github.com/infiniteloopcloud/protoc-gen-go-types/log"
"github.com/infiniteloopcloud/protoc-gen-go-types/reflect/protoreflect"
"github.com/infiniteloopcloud/protoc-gen-go-types/runtime/protoimpl"

Expand All @@ -33,13 +35,24 @@ const (
EnvTypeOverride = "TYPE_OVERRIDE"
)

type overrideParams struct {
goType string
goImport string
goImportAlias string
}

// overrideFields stores all the found messages which are created to override types
var overrideFields = make(map[string]map[string]overrideParams)

// SupportedFeatures reports the set of supported protobuf language features.
var SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)

// GenerateVersionMarkers specifies whether to generate version markers.
var GenerateProtobufSpecific = true
var TypeOverride = false

var lookupExtensionNames = []string{"go_type", "go_import", "go_import_alias"}

// Standard library dependencies.
const (
base64Package = protogen.GoImportPath("encoding/base64")
Expand Down Expand Up @@ -80,6 +93,8 @@ func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated
TypeOverride = true
}

log.Log("GenerateFile>> SkipProtobufSpecific:%v | TypeOverride:%v\n", GenerateProtobufSpecific, TypeOverride)

filename := file.GeneratedFilenamePrefix + ".pb.go"
g := gen.NewGeneratedFile(filename, file.GoImportPath)
f := newFileInfo(file)
Expand All @@ -103,6 +118,11 @@ func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated
g.P()
}

for _, message := range f.allMessages {
buildOverrides(message)
}
log.Log("overrides count: %d\t data: %#v", len(overrideFields), overrideFields)
genOverrideImports(g)
for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ {
genImport(gen, g, f, imps.Get(i))
}
Expand All @@ -123,6 +143,90 @@ func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated
return g
}

func genOverrideImports(g *protogen.GeneratedFile) {
for _, overrideMessage := range overrideFields {
for _, o := range overrideMessage {
if o.goImport != "" {
g.QualifiedGoIdent(protogen.GoIdent{
GoImportAlias: protogen.GoPackageName(o.goImportAlias),
GoImportPath: protogen.GoImportPath(o.goImport),
})
}
}
}
}

func buildOverrides(message *messageInfo) {
// Skip pre-declared
if strings.HasPrefix(string(message.Desc.FullName()), "google.protobuf.") {
return
}

log.Log("processing message: %s", message.Desc.FullName())

for _, field := range message.Fields {
var override overrideParams
log.Log("processing field: %s", field.GoName)

if uop, uok := processUninterpretedOptions(field); uok {
override = uop
} else if eop, eok := processExtensions(field); eok {
override = eop
}

if override.goType != "" {
if _, ok := overrideFields[message.GoIdent.GoName]; !ok {
overrideFields[message.GoIdent.GoName] = make(map[string]overrideParams)
}
overrideFields[message.GoIdent.GoName][field.GoName] = override
}
}
}

func processExtensions(field *protogen.Field) (overrideParams, bool) {
var override overrideParams
var ok bool
for _, ex := range field.Desc.Options().(*descriptorpb.FieldOptions).GetExtensionFields() {
switch string(ex.Type().TypeDescriptor().FullName()) {
case impl.FieldOptionGoType:
override.goType = ex.Value().String()
ok = true
case impl.FieldOptionGoImport:
override.goImport = ex.Value().String()
ok = true
case impl.FieldOptionGoImportAlias:
override.goImportAlias = ex.Value().String()
ok = true
}
}
return override, ok
}

func processUninterpretedOptions(field *protogen.Field) (overrideParams, bool) {
var override overrideParams
var ok bool
for _, o := range field.Desc.Options().(*descriptorpb.FieldOptions).GetUninterpretedOption() {
for _, namePart := range o.Name {
if namePart != nil {
log.Log("message field: %s", namePart.GetNamePart())
switch namePart.GetNamePart() {
case impl.FieldOptionGoType:
override.goType = string(o.GetStringValue())
ok = true
case impl.FieldOptionGoImport:
override.goImport = string(o.GetStringValue())
ok = true
case impl.FieldOptionGoImportAlias:
override.goImportAlias = string(o.GetStringValue())
ok = true
}
}
}
}

return override, ok
}

// genStandaloneComments prints all leading comments for a FileDescriptorProto
// location identified by the field number n.
func genStandaloneComments(g *protogen.GeneratedFile, f *fileInfo, n int32) {
Expand Down Expand Up @@ -419,7 +523,7 @@ func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, fie
if pointer {
goType = "*" + goType
}
goType = goTypeOverride(goType)
goType, _ = goTypeOverride(goType, m.GoIdent.GoName, field.GoName)
tags := structTags{
{"protobuf", fieldProtobufTagValue(field)},
{"json", fieldJSONTagValue(field)},
Expand Down Expand Up @@ -459,7 +563,7 @@ func genMessageDefaultDecls(g *protogen.GeneratedFile, f *fileInfo, m *messageIn
}
name := "Default_" + m.GoIdent.GoName + "_" + field.GoName
goType, _ := fieldGoType(g, f, field)
goType = goTypeOverride(goType)
goType, _ = goTypeOverride(goType, m.GoIdent.GoName, field.GoName)
defVal := field.Desc.Default()
switch field.Desc.Kind() {
case protoreflect.StringKind:
Expand Down Expand Up @@ -582,8 +686,9 @@ func genMessageGetterMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageI

// Getter for message field.
goType, pointer := fieldGoType(g, f, field)
goType = goTypeOverride(goType)
defaultValue := fieldDefaultValue(g, f, m, field)
var overwritten bool
goType, overwritten = goTypeOverride(goType, m.GoIdent.GoName, field.GoName)
defaultValue := fieldDefaultValue(g, f, m, field, goType, overwritten)
g.Annotate(m.GoIdent.GoName+".Get"+field.GoName, field.Location)
leadingComments := appendDeprecationSuffix("",
field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
Expand Down Expand Up @@ -709,7 +814,10 @@ func fieldProtobufTagValue(field *protogen.Field) string {
return tag.Marshal(field.Desc, enumName)
}

func fieldDefaultValue(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field) string {
func fieldDefaultValue(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, goType string, overwritten bool) string {
if overwritten {
return overwrittenDefault(goType)
}
if field.Desc.IsList() {
return "nil"
}
Expand All @@ -732,8 +840,8 @@ func fieldDefaultValue(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, f
if field.Desc.HasOptionalKeyword() {
return "nil"
} else {
goType := g.QualifiedGoIdent(field.Message.GoIdent)
goType = goTypeOverride(goType)
// goType := g.QualifiedGoIdent(field.Message.GoIdent)
// goType, _ = goTypeOverride(goType, m.GoIdent.GoName, field.GoName)
return goType + "{}"
}
case protoreflect.EnumKind:
Expand All @@ -751,6 +859,28 @@ func fieldDefaultValue(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, f
}
}

func overwrittenDefault(goType string) string {
switch goType {
case "bool":
return "false"
case "uint", "uint8", "uint16", "uint32", "uint64", "int", "int8", "int16", "int32", "int64", "float32", "float64", "uintptr", "byte", "rune":
return "0"
case "complex64":
return "complex64(0)"
case "complex128":
return "complex128(0)"
case "any", "interface{}":
return "nil"
}
if strings.HasPrefix(goType, "map") ||
strings.HasPrefix(goType, "*") ||
strings.HasPrefix(goType, "[]") {
// TODO implement custom interface in the future
return "nil"
}
return goType + "{}"
}

func fieldJSONTagValue(field *protogen.Field) string {
return string(field.Desc.Name()) + ",omitempty"
}
Expand Down Expand Up @@ -921,16 +1051,17 @@ func (c trailingComment) String() string {
return s
}

func goTypeOverride(goType string) string {
func goTypeOverride(goType string, msgName string, fieldName string) (string, bool) {
if TypeOverride {
switch goType {
case "TimeTime":
return "time.Time"
}

if strings.Contains(goType, "RepeatedString") {
return strings.ReplaceAll(goType, "RepeatedString", "[]string")
// TODO check the case when goType is a map
if oMsg, okMsg := overrideFields[msgName]; okMsg {
if o, okField := oMsg[fieldName]; okField {
return o.goType, true
}
}
// if strings.Contains(goType, "RepeatedString") {
// return strings.ReplaceAll(goType, "RepeatedString", "[]string")
// }
}
return goType
return goType, false
}
16 changes: 14 additions & 2 deletions cmd/protoc-gen-go/internal_gengo/main_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
package internal_gengo

import (
"bytes"
"io"
"os"
"testing"

"github.com/infiniteloopcloud/protoc-gen-go-types/parser"
)

func TestGenerateFile(t *testing.T) {
gen, err := parser.Parse("./test_data/test.proto")
t.Setenv("TYPE_OVERRIDE", "true")
gen, err := parser.Parse("google/protobuf/descriptor.proto", "./test_data/config.proto", "./test_data/test.proto")
if err != nil {
t.Fatal(err)
}

for _, f := range gen.Files {
if f.Generate {
GenerateFile(gen, f)
content, err := GenerateFile(gen, f).Content()
if err != nil {
t.Fatal(err)
}
f, err := os.Create("./test_data/" + f.GeneratedFilenamePrefix + ".pb.go")
if err != nil {
t.Fatal(err)
}
io.Copy(f, bytes.NewReader(content))
}
}
}
13 changes: 13 additions & 0 deletions cmd/protoc-gen-go/internal_gengo/test_data/config.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
syntax = "proto3";

package proto;

option go_package = ".;proto";

import "google/protobuf/descriptor.proto";

extend google.protobuf.FieldOptions {
optional string go_type = 1001;
optional string go_import = 1002;
optional string go_import_alias = 1003;
}
36 changes: 24 additions & 12 deletions cmd/protoc-gen-go/internal_gengo/test_data/test.proto
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
syntax = "proto3";

package proto;

option go_package = ".;proto";

package proto;
//import "./test_data/config.proto";
// to test with other languages, switch to this import and run
// cd cmd/protoc-gen-go/internal_gengo
// TYPE_OVERRIDE=true DEBUG=true protoc --go_out=test_data -I test_data test_data/test.proto
import "config.proto";

// Or use the following hard coded field option extension
//import "google/protobuf/descriptor.proto";
//
//extend google.protobuf.FieldOptions {
// optional string go_type = 1001;
// optional string go_import = 1002;
// optional string go_import_alias = 1003;
//}

message RepeatedString {
}

message Test {
TimeTime created_at = 1;
map<uint64, RepeatedString> map_field = 2;
int64 created_at = 1 [(go_type) = "time.Time", (go_import) = "time"];
map<uint64, RepeatedString> map_field = 2 [(go_type) = "map[uint64][]string", (go_import) = ""];
string test = 3;
String other = 43;
string optStr = 4 [(go_type) = "null.String", (go_import) = "github.com/volatiletech/null/v9", (go_import_alias) = "null"];
int32 optInt = 5 [(go_type) = "null.Int32"];
int32 optBigInt = 6 [(go_type) = "null.Int64", (go_import) = "github.com/volatiletech/null/v9", (go_import_alias) = "null"];
}

message TimeTime {}

message RepeatedString {}

message String {
string string = 2;
}
Loading

0 comments on commit 502726f

Please sign in to comment.