Skip to content

Commit

Permalink
add option to set import alias manually
Browse files Browse the repository at this point in the history
extend the cmd with type and import overrides
  • Loading branch information
borosr committed Oct 28, 2022
1 parent 42d4020 commit 7d9853d
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 29 deletions.
86 changes: 74 additions & 12 deletions cmd/protoc-gen-go/internal_gengo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,21 @@ import (
const (
EnvSkipProtobufSpecific = "SKIP_PROTOBUF_SPECIFIC"
EnvTypeOverride = "TYPE_OVERRIDE"

fieldOptionGoType = "go_type"
fieldOptionGoImport = "go_import"
fieldOptionGoImportAlias = "go_import_alias"
)

type overrideParams struct {
goType string
goImport string
goImportAlias string
}

// overrideFields stores all the found messages which are created to override types
var overrideFields = make(map[string]map[string]overrideParams)

// SupportedFeatures reports the set of supported protobuf language features.
var SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)

Expand Down Expand Up @@ -103,6 +116,10 @@ func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated
g.P()
}

for _, message := range f.allMessages {
buildOverrides(message)
}
genOverrideImports(g)
for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ {
genImport(gen, g, f, imps.Get(i))
}
Expand All @@ -123,6 +140,50 @@ func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.Generated
return g
}

func genOverrideImports(g *protogen.GeneratedFile) {
for _, overrideMessage := range overrideFields {
for _, o := range overrideMessage {
if o.goImport != "" {
g.QualifiedGoIdent(protogen.GoIdent{
GoImportAlias: protogen.GoPackageName(o.goImportAlias),
GoImportPath: protogen.GoImportPath(o.goImport),
})
}
}
}
}

func buildOverrides(message *messageInfo) {
// Skip pre-declared
if strings.HasPrefix(string(message.Desc.FullName()), "google.protobuf.") {
return
}

for _, field := range message.Fields {
var override overrideParams
for _, o := range field.Desc.Options().(*descriptorpb.FieldOptions).GetUninterpretedOption() {
for _, namePart := range o.Name {
if namePart != nil {
switch namePart.GetNamePart() {
case fieldOptionGoType:
override.goType = string(o.GetStringValue())
case fieldOptionGoImport:
override.goImport = string(o.GetStringValue())
case fieldOptionGoImportAlias:
override.goImportAlias = string(o.GetStringValue())
}
}
}
}
if override.goType != "" {
if _, ok := overrideFields[message.GoIdent.GoName]; !ok {
overrideFields[message.GoIdent.GoName] = make(map[string]overrideParams)
}
overrideFields[message.GoIdent.GoName][field.GoName] = override
}
}
}

// genStandaloneComments prints all leading comments for a FileDescriptorProto
// location identified by the field number n.
func genStandaloneComments(g *protogen.GeneratedFile, f *fileInfo, n int32) {
Expand Down Expand Up @@ -419,7 +480,7 @@ func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, fie
if pointer {
goType = "*" + goType
}
goType = goTypeOverride(goType)
goType = goTypeOverride(goType, m.GoIdent.GoName, field.GoName)
tags := structTags{
{"protobuf", fieldProtobufTagValue(field)},
{"json", fieldJSONTagValue(field)},
Expand Down Expand Up @@ -459,7 +520,7 @@ func genMessageDefaultDecls(g *protogen.GeneratedFile, f *fileInfo, m *messageIn
}
name := "Default_" + m.GoIdent.GoName + "_" + field.GoName
goType, _ := fieldGoType(g, f, field)
goType = goTypeOverride(goType)
goType = goTypeOverride(goType, m.GoIdent.GoName, field.GoName)
defVal := field.Desc.Default()
switch field.Desc.Kind() {
case protoreflect.StringKind:
Expand Down Expand Up @@ -582,7 +643,7 @@ func genMessageGetterMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageI

// Getter for message field.
goType, pointer := fieldGoType(g, f, field)
goType = goTypeOverride(goType)
goType = goTypeOverride(goType, m.GoIdent.GoName, field.GoName)
defaultValue := fieldDefaultValue(g, f, m, field)
g.Annotate(m.GoIdent.GoName+".Get"+field.GoName, field.Location)
leadingComments := appendDeprecationSuffix("",
Expand Down Expand Up @@ -733,7 +794,7 @@ func fieldDefaultValue(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, f
return "nil"
} else {
goType := g.QualifiedGoIdent(field.Message.GoIdent)
goType = goTypeOverride(goType)
goType = goTypeOverride(goType, m.GoIdent.GoName, field.GoName)
return goType + "{}"
}
case protoreflect.EnumKind:
Expand Down Expand Up @@ -921,16 +982,17 @@ func (c trailingComment) String() string {
return s
}

func goTypeOverride(goType string) string {
func goTypeOverride(goType string, msgName string, fieldName string) string {
if TypeOverride {
switch goType {
case "TimeTime":
return "time.Time"
}

if strings.Contains(goType, "RepeatedString") {
return strings.ReplaceAll(goType, "RepeatedString", "[]string")
// TODO check the case when goType is a map
if oMsg, okMsg := overrideFields[msgName]; okMsg {
if o, okField := oMsg[fieldName]; okField {
return o.goType
}
}
// if strings.Contains(goType, "RepeatedString") {
// return strings.ReplaceAll(goType, "RepeatedString", "[]string")
// }
}
return goType
}
16 changes: 14 additions & 2 deletions cmd/protoc-gen-go/internal_gengo/main_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
package internal_gengo

import (
"bytes"
"io"
"os"
"testing"

"github.com/infiniteloopcloud/protoc-gen-go-types/parser"
)

func TestGenerateFile(t *testing.T) {
gen, err := parser.Parse("./test_data/test.proto")
t.Setenv("TYPE_OVERRIDE", "true")
gen, err := parser.Parse("google/protobuf/descriptor.proto", "./test_data/config.proto", "./test_data/test.proto")
if err != nil {
t.Fatal(err)
}

for _, f := range gen.Files {
if f.Generate {
GenerateFile(gen, f)
content, err := GenerateFile(gen, f).Content()
if err != nil {
t.Fatal(err)
}
f, err := os.Create("./test_data/" + f.GeneratedFilenamePrefix + ".pb.go")
if err != nil {
t.Fatal(err)
}
io.Copy(f, bytes.NewReader(content))
}
}
}
11 changes: 11 additions & 0 deletions cmd/protoc-gen-go/internal_gengo/test_data/config.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

package proto;

import "google/protobuf/descriptor.proto";

extend google.protobuf.FieldOptions {
optional string go_type = 1000;
optional string go_import = 1001;
optional string go_import_alias = 1002;
}
23 changes: 11 additions & 12 deletions cmd/protoc-gen-go/internal_gengo/test_data/test.proto
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
syntax = "proto3";

package proto;

option go_package = ".;proto";

package proto;
import "./test_data/config.proto";

message RepeatedString {
}

message Test {
TimeTime created_at = 1;
map<uint64, RepeatedString> map_field = 2;
int64 created_at = 1 [(go_type) = "time.Time", (go_import) = "time"];
map<uint64, RepeatedString> map_field = 2 [(go_type) = "map[uint64][]string", (go_import) = ""];
string test = 3;
String other = 43;
string optStr = 4 [(go_type) = "null.String", (go_import) = "github.com/volatiletech/null/v9", (go_import_alias) = "null"];
int32 optInt = 5 [(go_type) = "null.Int32"];
int32 optBigInt = 6 [(go_type) = "null.Int64", (go_import) = "github.com/volatiletech/null/v9", (go_import_alias) = "null"];
}

message TimeTime {}

message RepeatedString {}

message String {
string string = 2;
}
12 changes: 9 additions & 3 deletions compiler/protogen/protogen.go
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,12 @@ func (g *GeneratedFile) QualifiedGoIdent(ident GoIdent) string {
if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
return string(packageName) + "." + ident.GoName
}
packageName := cleanPackageName(path.Base(string(ident.GoImportPath)))
var packageName GoPackageName
if ident.GoImportAlias == "" {
packageName = cleanPackageName(path.Base(string(ident.GoImportPath)))
} else {
packageName = ident.GoImportAlias
}
for i, orig := 1, packageName; g.usedPackageNames[packageName]; i++ {
packageName = orig + GoPackageName(strconv.Itoa(i))
}
Expand Down Expand Up @@ -1163,8 +1168,9 @@ func (g *GeneratedFile) metaFile(content []byte) (string, error) {
// A GoIdent is a Go identifier, consisting of a name and import path.
// The name is a single identifier and may not be a dot-qualified selector.
type GoIdent struct {
GoName string
GoImportPath GoImportPath
GoName string
GoImportAlias GoPackageName
GoImportPath GoImportPath
}

func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) }
Expand Down

0 comments on commit 7d9853d

Please sign in to comment.