Skip to content

Commit

Permalink
Merge pull request horizon-games#1 from ravenops/feature/fix-import-p…
Browse files Browse the repository at this point in the history
…athing

Fix import pathing
  • Loading branch information
ccpost authored Jul 5, 2019
2 parents a30af02 + 5554f86 commit fe7ae49
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 182 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
*.ts
*.js
node_modules

.idea/
2 changes: 1 addition & 1 deletion dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (d *dependencyResolver) Resolve(typeName string) (*descriptor.FileDescripto
func (d *dependencyResolver) TypeName(fd *descriptor.FileDescriptorProto, typeName string) string {
orig, err := d.Resolve(fullTypeName(fd, typeName))
if err == nil {
if !samePackage(fd, orig) {
if fd.GetPackage() != orig.GetPackage() {
return importName(orig) + "." + typeName
}
}
Expand Down
172 changes: 73 additions & 99 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,49 +10,8 @@ import (
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
)

type packageFile struct {
name string
pf []*protoFile
}

func (f *packageFile) addProto(pf *protoFile) {
f.pf = append(f.pf, pf)
}

func (f *packageFile) protoFile() *protoFile {
pf := &protoFile{
Imports: map[string]*importValues{},
Messages: []*messageValues{},
Services: []*serviceValues{},
Enums: []*enumValues{},
}
for i := range f.pf {
for j := range f.pf[i].Imports {
pf.Imports[j] = f.pf[i].Imports[j]
}
pf.Messages = append(pf.Messages, f.pf[i].Messages...)
pf.Services = append(pf.Services, f.pf[i].Services...)
pf.Enums = append(pf.Enums, f.pf[i].Enums...)
}
return pf
}

var (
packageFiles = map[string]*packageFile{}
)

func addProtoToPackage(fileName string, pf *protoFile) {
if _, ok := packageFiles[fileName]; !ok {
packageFiles[fileName] = &packageFile{name: fileName}
}
packageFiles[fileName].addProto(pf)
}

func samePackage(a *descriptor.FileDescriptorProto, b *descriptor.FileDescriptorProto) bool {
if a.GetPackage() != b.GetPackage() {
return false
}
return true
func sameFile(a *descriptor.FileDescriptorProto, b *descriptor.FileDescriptorProto) bool {
return a.GetName() == b.GetName()
}

func fullTypeName(fd *descriptor.FileDescriptorProto, typeName string) string {
Expand All @@ -71,16 +30,18 @@ func generate(req *plugin.CodeGeneratorRequest) (*plugin.CodeGeneratorResponse,
},
}

outputFiles := make(map[string][]*protoFile)
protoFiles := req.GetProtoFile()
for i := range protoFiles {
file := protoFiles[i]

for _, file := range protoFiles {
pfile := &protoFile{
Imports: map[string]*importValues{},
Messages: []*messageValues{},
Services: []*serviceValues{},
Enums: []*enumValues{},
Output: tsFileName(file),
RelativeImportBase: relativeImportBase(file),
Imports: map[string]*importValues{},
Messages: []*messageValues{},
Services: []*serviceValues{},
Enums: []*enumValues{},
}
outputFiles[tsImportPath(file)] = append(outputFiles[tsImportPath(file)], pfile)

// Add enum
for _, enum := range file.GetEnumType() {
Expand All @@ -102,8 +63,30 @@ func generate(req *plugin.CodeGeneratorRequest) (*plugin.CodeGeneratorResponse,
}

// Add messages
for _, message := range file.GetMessageType() {
name := message.GetName()
type collectMsg struct {
Name string
FD *descriptor.DescriptorProto
}
var allMsgs []collectMsg
// Recurse through message definitions first
var collectMsgDefs func(msg *descriptor.DescriptorProto, parents []string)
collectMsgDefs = func(msg *descriptor.DescriptorProto, parents []string) {
parents = append(parents, msg.GetName())
allMsgs = append(allMsgs, collectMsg{
Name: strings.Join(parents, "_"),
FD: msg,
})
for _, m := range msg.GetNestedType() {
collectMsgDefs(m, parents)
}
}
for _, msg := range file.GetMessageType() {
collectMsgDefs(msg, nil)
}
// Parse them all in flattened form and add to the list
for _, collect := range allMsgs {
message := collect.FD
name := collect.Name
tsInterface := typeToInterface(name)
jsonInterface := typeToJSONInterface(name)

Expand All @@ -121,12 +104,6 @@ func generate(req *plugin.CodeGeneratorRequest) (*plugin.CodeGeneratorResponse,
NestedEnums: []*enumValues{},
}

if len(message.GetNestedType()) > 0 {
// TODO: add support for nested messages
// https://developers.google.com/protocol-buffers/docs/proto#nested
log.Printf("warning: nested messages are not supported yet")
}

// Add nested enums
for _, enum := range message.GetEnumType() {
e := &enumValues{
Expand All @@ -146,18 +123,14 @@ func generate(req *plugin.CodeGeneratorRequest) (*plugin.CodeGeneratorResponse,

// Add message fields
for _, field := range message.GetField() {
typeName := resolver.TypeName(file, singularFieldType(message, field))
fp, err := resolver.Resolve(field.GetTypeName())
if err == nil {
if !samePackage(fp, file) {
pfile.Imports[fp.GetPackage()] = &importValues{
Name: importName(fp),
Path: importPath(file, fp.GetPackage()),
}
if !sameFile(fp, file) {
pfile.AddImport(fp, typeName)
}
}

typeName := resolver.TypeName(file, singularFieldType(message, field))

v.Fields = append(v.Fields, &fieldValues{
Name: field.GetName(),
Field: camelCase(field.GetName()),
Expand All @@ -183,56 +156,64 @@ func generate(req *plugin.CodeGeneratorRequest) (*plugin.CodeGeneratorResponse,
}

for _, method := range service.GetMethod() {
inputType := resolver.TypeName(file, removePkg(method.GetInputType()))
outputType := resolver.TypeName(file, removePkg(method.GetOutputType()))
{
fp, err := resolver.Resolve(method.GetInputType())
if err == nil {
if !samePackage(fp, file) {
pfile.Imports[fp.GetPackage()] = &importValues{
Name: importName(fp),
Path: importPath(file, fp.GetPackage()),
}
if !sameFile(fp, file) {
pfile.AddImport(fp, inputType)
}
}
}

{
fp, err := resolver.Resolve(method.GetOutputType())
if err == nil {
if !samePackage(fp, file) {
pfile.Imports[fp.GetPackage()] = &importValues{
Name: importName(fp),
Path: importPath(file, fp.GetPackage()),
}
if !sameFile(fp, file) {
pfile.AddImport(fp, outputType)
}
}
}

v.Methods = append(v.Methods, &serviceMethodValues{
Name: method.GetName(),
InputType: resolver.TypeName(file, removePkg(method.GetInputType())),
OutputType: resolver.TypeName(file, removePkg(method.GetOutputType())),
InputType: inputType,
OutputType: outputType,
})
}

pfile.Services = append(pfile.Services, v)
}

// Add to appropriate file
addProtoToPackage(tsFileName(file), pfile)
}

for packageName := range packageFiles {
pf := packageFiles[packageName]
for tsPath, pff := range outputFiles {
ev := &exportValues{}

for _, pf := range pff {
ev.Exports = append(ev.Exports, strings.TrimSuffix(path.Base(pf.Output), ".ts"))

// Compile to typescript
content, err := pf.Compile()
if err != nil {
log.Fatal("could not compile template: ", err)
}

// Add to file list
res.File = append(res.File, &plugin.CodeGeneratorResponse_File{
Name: &pf.Output,
Content: &content,
})
}

// Compile to typescript
content, err := pf.protoFile().Compile()
content, err := ev.Compile()
if err != nil {
log.Fatal("could not compile template: ", err)
}

// Add to file list
name := path.Join(tsPath, "index.ts")
res.File = append(res.File, &plugin.CodeGeneratorResponse_File{
Name: &pf.name,
Name: &name,
Content: &content,
})
}
Expand Down Expand Up @@ -281,24 +262,17 @@ func tsImportName(name string) string {
return base[0 : len(base)-len(path.Ext(base))]
}

func tsImportPath(name string) string {
base := path.Base(name)
name = name[0 : len(name)-len(path.Ext(base))]
return name
func tsImportPath(fd *descriptor.FileDescriptorProto) string {
return path.Join(strings.Split(fd.GetPackage(), ".")...)
}

func importPath(fd *descriptor.FileDescriptorProto, name string) string {
// TODO: how to resolve relative paths?
return tsImportPath(name)
func relativeImportBase(fd *descriptor.FileDescriptorProto) string {
return strings.Repeat("../", len(strings.Split(tsImportPath(fd), "/")))
}

func tsFileName(fd *descriptor.FileDescriptorProto) string {
packageName := fd.GetPackage()
if packageName == "" {
packageName = path.Base(fd.GetName())
}
name := path.Join(path.Dir(fd.GetName()), packageName)
return tsImportPath(name) + ".ts"
filename := strings.TrimSuffix(path.Base(fd.GetName()), path.Ext(fd.GetName())) + ".ts"
return path.Join(tsImportPath(fd), filename)
}

func singularFieldType(m *descriptor.DescriptorProto, f *descriptor.FieldDescriptorProto) string {
Expand Down
Loading

0 comments on commit fe7ae49

Please sign in to comment.