Skip to content

Commit

Permalink
feat: add protoloader package
Browse files Browse the repository at this point in the history
Loads the package-level file descriptor set during runtime for
configurable proto messages.

To be used for code generation of protobuf/Spanner data conversion, and
a protobuf-first query API.
  • Loading branch information
odsod committed Jan 17, 2021
1 parent 97d16ca commit 9180786
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 37 deletions.
22 changes: 16 additions & 6 deletions aip-spanner-go.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
databases:
- name: music
schema:
- "testdata/migrations/music/*.up.sql"
package:
name: musicdb
path: ./internal/examples/musicdb

- name: freight
schema:
- "testdata/migrations/freight/*.up.sql"
package:
name: freightdb
path: ./internal/examples/freightdb

- name: music
schema:
- "testdata/migrations/music/*.up.sql"
package:
name: musicdb
path: ./internal/examples/musicdb
resources:
- message: go.einride.tech/aip/examples/proto/gen/einride/example/freight/v1.Shipper
table: shippers

- message: go.einride.tech/aip/examples/proto/gen/einride/example/freight/v1.Site
table: sites

- message: go.einride.tech/aip/examples/proto/gen/einride/example/freight/v1.Shipment
table: shipments
49 changes: 18 additions & 31 deletions cmd/aip-spanner-go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"go.einride.tech/aip-spanner/internal/codegen"
"go.einride.tech/aip-spanner/internal/codegen/databasecodegen"
"go.einride.tech/aip-spanner/internal/codegen/descriptorcodegen"
"go.einride.tech/aip-spanner/spanddl"
"go.einride.tech/aip-spanner/internal/config"
"gopkg.in/yaml.v2"
)

Expand All @@ -33,39 +33,26 @@ func main() {
log.Panic(err)
}
}()
var config struct {
Databases []struct {
Name string `yaml:"name"`
SchemaGlobs []string `yaml:"schema"`
Package struct {
Name string `yaml:"name"`
Path string `yaml:"path"`
} `yaml:"package"`
} `yaml:"databases"`
}
if err := yaml.NewDecoder(configFile).Decode(&config); err != nil {
var codeGenerationConfig config.CodeGenerationConfig
if err := yaml.NewDecoder(configFile).Decode(&codeGenerationConfig); err != nil {
log.Panic(err)
}
for _, databaseConfig := range config.Databases {
var db spanddl.Database
for _, schemaGlob := range databaseConfig.SchemaGlobs {
schemaFiles, err := filepath.Glob(schemaGlob)
for _, databaseConfig := range codeGenerationConfig.Databases {
db, err := databaseConfig.LoadDatabase()
if err != nil {
log.Panic(err)
}
for _, resourceConfig := range databaseConfig.Resources {
table, ok := db.Table(spansql.ID(resourceConfig.Table))
if !ok {
log.Panicf("unknown table %s in database %s", resourceConfig.Table, databaseConfig.Name)
}
messageDescriptor, err := resourceConfig.LoadMessageDescriptor()
if err != nil {
log.Panic(err)
}
for _, schemaFile := range schemaFiles {
schema, err := ioutil.ReadFile(schemaFile)
if err != nil {
log.Panic(err)
}
ddl, err := spansql.ParseDDL(schemaFile, string(schema))
if err != nil {
log.Panic(err)
}
if err := db.ApplyDDL(ddl); err != nil {
log.Panic(err)
}
}
// TODO: Use table and message descriptor for code generation.
_, _ = table, messageDescriptor
}
if err := os.MkdirAll(databaseConfig.Package.Path, 0o775); err != nil {
log.Panic(err)
Expand All @@ -77,7 +64,7 @@ func main() {
GeneratedBy: generatedBy,
})
descriptorcodegen.DatabaseDescriptorCodeGenerator{
Database: &db,
Database: db,
}.GenerateCode(f)
content, err := f.Content()
if err != nil {
Expand All @@ -94,7 +81,7 @@ func main() {
Package: databaseConfig.Package.Name,
GeneratedBy: generatedBy,
})
databasecodegen.DatabaseCodeGenerator{Database: &db}.GenerateCode(f)
databasecodegen.DatabaseCodeGenerator{Database: db}.GenerateCode(f)
content, err := f.Content()
if err != nil {
log.Panic(err)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
google.golang.org/api v0.36.0
google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7
google.golang.org/grpc v1.35.0
google.golang.org/protobuf v1.25.0
gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.0.3
)
64 changes: 64 additions & 0 deletions internal/config/codegeneration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package config

import (
"fmt"
"strings"

"go.einride.tech/aip-spanner/internal/protoloader"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)

// CodeGenerationConfig contains config for code generation.
type CodeGenerationConfig struct {
// Databases to generate code for.
Databases []DatabaseConfig `yaml:"databases"`
}

// ResourceConfig contains code generation config for a resource.
type ResourceConfig struct {
// Message contains the Go package path and message name to use for the resource.
// Example: go.einride.tech/aip/examples/proto/gen/einride/example/freight/v1.Shipper.
Message string `yaml:"message"`
// Table is the name of the table used for storing the resource.
Table string `yaml:"table"`
}

// LoadMessageDescriptor loads the protobuf descriptor for the configured message.
func (r *ResourceConfig) LoadMessageDescriptor() (protoreflect.MessageDescriptor, error) {
i := strings.LastIndexByte(r.Message, '.')
if i == -1 {
return nil, fmt.Errorf("load message descriptor: invalid message format %s", r.Message)
}
goImportPath, messageName := r.Message[:i], r.Message[i+1:]
files, err := protoloader.LoadFilesFromGoPackage(goImportPath)
if err != nil {
return nil, fmt.Errorf("load message descriptor %s: %w", r.Message, err)
}
var result protoreflect.MessageDescriptor
files.RangeFiles(func(file protoreflect.FileDescriptor) bool {
fileOptions, ok := file.Options().(*descriptorpb.FileOptions)
if !ok {
return true
}
goPackage := fileOptions.GetGoPackage()
if i := strings.LastIndexByte(goPackage, ';'); i != -1 {
goPackage = goPackage[:i]
}
if goPackage != goImportPath {
return true
}
for i := 0; i < file.Messages().Len(); i++ {
message := file.Messages().Get(i)
if message.Name() == protoreflect.Name(messageName) {
result = message
return false
}
}
return true
})
if result == nil {
return nil, fmt.Errorf("found no descriptor for message %s", r.Message)
}
return result, nil
}
55 changes: 55 additions & 0 deletions internal/config/database.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package config

import (
"fmt"
"io/ioutil"
"path/filepath"

"cloud.google.com/go/spanner/spansql"
"go.einride.tech/aip-spanner/spanddl"
)

// DatabaseConfig contains code generation config for a database.
type DatabaseConfig struct {
// Name of the database.
Name string `yaml:"name"`
// SchemaGlobs are read in ass
SchemaGlobs []string `yaml:"schema"`
// Package is the config for database's generated Go package.
Package GoPackageConfig `yaml:"package"`
// Resources are the config for the databases generated resource APIs.
Resources []ResourceConfig `yaml:"resources"`
}

// LoadDatabase loads the configured database.
func (c *DatabaseConfig) LoadDatabase() (*spanddl.Database, error) {
var db spanddl.Database
for _, schemaGlob := range c.SchemaGlobs {
schemaFiles, err := filepath.Glob(schemaGlob)
if err != nil {
return nil, fmt.Errorf("load database %s: %w", c.Name, err)
}
for _, schemaFile := range schemaFiles {
schema, err := ioutil.ReadFile(schemaFile)
if err != nil {
return nil, fmt.Errorf("load database %s: %w", c.Name, err)
}
ddl, err := spansql.ParseDDL(schemaFile, string(schema))
if err != nil {
return nil, fmt.Errorf("load database %s: %w", c.Name, err)
}
if err := db.ApplyDDL(ddl); err != nil {
return nil, fmt.Errorf("load database %s: %w", c.Name, err)
}
}
}
return &db, nil
}

// GoPackageConfig contains code generation config for a Go package.
type GoPackageConfig struct {
// Name is the package name.
Name string `yaml:"name"`
// Path is the package import path.
Path string `yaml:"path"`
}
2 changes: 2 additions & 0 deletions internal/config/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package config contains configuration for the AIP Spanner Go code generator.
package config
88 changes: 88 additions & 0 deletions internal/protoloader/loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package protoloader

import (
"bytes"
"encoding/base64"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"text/template"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
)

func LoadFilesFromGoPackage(goPackage string) (*protoregistry.Files, error) {
tmpDir, err := ioutil.TempDir(".", "protoloader*")
if err != nil {
return nil, fmt.Errorf("load proto files from Go package %s: %w", goPackage, err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
panic(fmt.Errorf("failed to clean up temporary dir: %s", tmpDir))
}
}()
filename := filepath.Join(tmpDir, "main.go")
f, err := os.Create(filename)
if err != nil {
return nil, fmt.Errorf("load proto files from Go package %s: %w", goPackage, err)
}
defer func() {
_ = f.Close()
}()
if err := mainTemplate.Execute(f, struct{ GoPackage string }{GoPackage: goPackage}); err != nil {
return nil, fmt.Errorf("load proto files from Go package %s: %w", goPackage, err)
}
cmd := exec.Command("go", "run", filename)
var stdout, stderr bytes.Buffer
cmd.Stdout, cmd.Stderr = &stdout, &stderr
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("go run %s: %s", filename, stderr.String())
}
data, err := base64.StdEncoding.DecodeString(stdout.String())
if err != nil {
return nil, fmt.Errorf("load proto files from Go package %s: %w", goPackage, err)
}
var fileSet descriptorpb.FileDescriptorSet
if err := proto.Unmarshal(data, &fileSet); err != nil {
return nil, fmt.Errorf("load proto files from Go package %s: %w", goPackage, err)
}
return protodesc.NewFiles(&fileSet)
}

// nolint: gochecknoglobals
var mainTemplate = template.Must(template.New("main").Parse(`
package main
import (
"encoding/base64"
"fmt"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
_ "{{.GoPackage}}" // package to load
)
func main() {
fileSet := &descriptorpb.FileDescriptorSet{
File: make([]*descriptorpb.FileDescriptorProto, 0, protoregistry.GlobalFiles.NumFiles()),
}
protoregistry.GlobalFiles.RangeFiles(func(file protoreflect.FileDescriptor) bool {
fileSet.File = append(fileSet.File, protodesc.ToFileDescriptorProto(file))
return true
})
data, err := proto.Marshal(fileSet)
if err != nil {
panic(err)
}
fmt.Print(base64.StdEncoding.EncodeToString(data))
}
`))
14 changes: 14 additions & 0 deletions internal/protoloader/loader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package protoloader

import (
"testing"

"gotest.tools/v3/assert"
)

func TestLoadFilesFromGoPackage(t *testing.T) {
t.Parallel()
files, err := LoadFilesFromGoPackage("go.einride.tech/aip/examples/proto/gen/einride/example/freight/v1")
assert.NilError(t, err)
assert.Assert(t, files.NumFiles() > 0)
}

0 comments on commit 9180786

Please sign in to comment.