Skip to content

Commit

Permalink
add support for full merge of base document (#38)
Browse files Browse the repository at this point in the history
* add support for full merge of base document

* add back gen code

* fix spacing

* revert base update

* update based off feedback

* remove unused config param

* revert buf.gen.yml

* update tests
  • Loading branch information
castlemilk authored Oct 30, 2024
1 parent 062704a commit 8442f9e
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 22 deletions.
1 change: 1 addition & 0 deletions converter/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ func WithBaseOpenAPI(baseOpenAPI []byte) Option {
}
}


// WithAllowGET sets a file to use as a base for all OpenAPI files.
func WithAllowGET(allowGet bool) Option {
return func(g *generator) error {
Expand Down
2 changes: 1 addition & 1 deletion example.base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ openapi: 3.1.0
info:
description: "THIS IS FROM BASE!"
title: WHOOOOP
version: v1.0.0
version: v1.0.0
69 changes: 69 additions & 0 deletions internal/converter/converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sudorandom/protoc-gen-connect-openapi/internal/converter"
"github.com/sudorandom/protoc-gen-connect-openapi/internal/converter/options"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"
pluginpb "google.golang.org/protobuf/types/pluginpb"
"gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -188,3 +191,69 @@ func makeOutputPath(protofile, format string) string {
dir, file := filepath.Split(strings.TrimSuffix(protofile, filepath.Ext(protofile)) + ".openapi." + format)
return filepath.Join(dir, "output", file)
}

func TestConvertWithOptions(t *testing.T) {
t.Run("with base file", func(t *testing.T) {
baseYAML := `
openapi: 3.1.0
info:
title: Base API
version: 1.0.0
x-logo:
url: https://example.com/logo.png
paths:
/example/api/path:
post:
x-code-samples:
- language: shell
label: example-api-path
source: |
curl -X POST https://api.example.com/example/api/path \
-H "Content-Type: application/json" \
-d '{"email": "user@example.com"}'
`
opts := options.Options{
Path: "test.openapi.yaml",
Format: "yaml",
BaseOpenAPI: []byte(baseYAML),
}

req := &pluginpb.CodeGeneratorRequest{
ProtoFile: []*descriptorpb.FileDescriptorProto{
{
Name: proto.String("test.proto"),
Package: proto.String("test"),
MessageType: []*descriptorpb.DescriptorProto{
{Name: proto.String("TestMessage")},
},
Service: []*descriptorpb.ServiceDescriptorProto{
{
Name: proto.String("ExampleApiService"),
Method: []*descriptorpb.MethodDescriptorProto{
{
Name: proto.String("ExampleApiPath"),
InputType: proto.String(".test.TestMessage"),
OutputType: proto.String(".test.TestMessage"),
},
},
},
},
},
},
FileToGenerate: []string{"test.proto"},
}

resp, err := converter.ConvertWithOptions(req, opts)
require.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.File, 1)

content := resp.File[0].GetContent()
assert.Contains(t, content, "x-logo:")
assert.Contains(t, content, "url: https://example.com/logo.png")
assert.Contains(t, content, "x-code-samples:")

// Check that the generated content is merged with the base file
assert.Contains(t, content, "TestMessage")
})
}
203 changes: 182 additions & 21 deletions internal/converter/paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,200 @@ func addPathItemsFromFile(opts options.Options, fd protoreflect.FileDescriptor,
for j := 0; j < methods.Len(); j++ {
method := methods.Get(j)
pathItems := googleapi.MakePathItems(opts, method)
for pair := pathItems.First(); pair != nil; pair = pair.Next() {
path, item := pair.Key(), pair.Value()
if existing, ok := paths.PathItems.Get(pair.Key()); !ok {
paths.PathItems.Set(path, item)

// Helper function to update or set path items
addPathItem := func(path string, newItem *v3.PathItem) {
if existing, ok := paths.PathItems.Get(path); !ok {
paths.PathItems.Set(path, newItem)
} else {
if item.Get != nil {
existing.Get = item.Get
}
if item.Post != nil {
existing.Post = item.Post
}
if item.Delete != nil {
existing.Delete = item.Delete
}
if item.Put != nil {
existing.Put = item.Put
}
if item.Patch != nil {
existing.Patch = item.Patch
}
mergePathItems(existing, newItem)
paths.PathItems.Set(path, existing)
}
}
// No google.api annotations for this method, so default to the ConnectRPC/gRPC path

// Update path items from google.api annotations
for pair := pathItems.First(); pair != nil; pair = pair.Next() {
addPathItem(pair.Key(), pair.Value())
}

// Default to ConnectRPC/gRPC path if no google.api annotations
if pathItems == nil || pathItems.Len() == 0 {
paths.PathItems.Set("/"+string(service.FullName())+"/"+string(method.Name()), methodToPathItem(opts, method))
path := "/" + string(service.FullName()) + "/" + string(method.Name())
addPathItem(path, methodToPathItem(opts, method))
}
}
}

return nil
}

func mergePathItems(existing, new *v3.PathItem) {
// Merge operations
operations := []struct {
existingOp **v3.Operation
newOp *v3.Operation
}{
{&existing.Get, new.Get},
{&existing.Post, new.Post},
{&existing.Put, new.Put},
{&existing.Delete, new.Delete},
{&existing.Options, new.Options},
{&existing.Head, new.Head},
{&existing.Patch, new.Patch},
{&existing.Trace, new.Trace},
}

for _, op := range operations {
if op.newOp != nil {
mergeOperation(op.existingOp, op.newOp)
}
}

// Merge other fields
if new.Summary != "" {
existing.Summary = new.Summary
}
if new.Description != "" {
existing.Description = new.Description
}
existing.Servers = append(existing.Servers, new.Servers...)
existing.Parameters = append(existing.Parameters, new.Parameters...)

// Merge extensions
for pair := new.Extensions.First(); pair != nil; pair = pair.Next() {
if _, ok := existing.Extensions.Get(pair.Key()); !ok {
existing.Extensions.Set(pair.Key(), pair.Value())
}
}
}

func mergeOperation(existing **v3.Operation, new *v3.Operation) {
if *existing == nil {
*existing = new
return
}
// Merge operation fields
if new.Summary != "" {
(*existing).Summary = new.Summary
}
if new.Description != "" {
(*existing).Description = new.Description
}
(*existing).Tags = append((*existing).Tags, new.Tags...)
(*existing).Parameters = append((*existing).Parameters, new.Parameters...)
if new.RequestBody != nil {
(*existing).RequestBody = new.RequestBody
}
if new.Responses != nil {
mergeResponses((*existing).Responses, new.Responses)
}
if new.Deprecated != nil {
(*existing).Deprecated = new.Deprecated
}

// Add support for additional Operation fields
if new.Callbacks != nil {
if (*existing).Callbacks == nil {
(*existing).Callbacks = orderedmap.New[string, *v3.Callback]()
}
for pair := new.Callbacks.First(); pair != nil; pair = pair.Next() {
if _, ok := (*existing).Callbacks.Get(pair.Key()); !ok {
(*existing).Callbacks.Set(pair.Key(), pair.Value())
}
}
}

if new.Security != nil {
(*existing).Security = append((*existing).Security, new.Security...)
}

if new.Servers != nil {
(*existing).Servers = append((*existing).Servers, new.Servers...)
}

if new.ExternalDocs != nil {
(*existing).ExternalDocs = new.ExternalDocs
}

// Merge extensions
for pair := new.Extensions.First(); pair != nil; pair = pair.Next() {
if _, ok := (*existing).Extensions.Get(pair.Key()); !ok {
(*existing).Extensions.Set(pair.Key(), pair.Value())
}
}
}

func mergeResponses(existing, new *v3.Responses) {
if existing == nil || new == nil {
return
}

// Merge response codes
for pair := new.Codes.First(); pair != nil; pair = pair.Next() {
code := pair.Key()
if existingResponse, ok := existing.Codes.Get(code); !ok {
existing.Codes.Set(code, pair.Value())
} else {
mergeResponse(existingResponse, pair.Value())
}
}

// Merge default response
if new.Default != nil {
if existing.Default == nil {
existing.Default = new.Default
} else {
mergeResponse(existing.Default, new.Default)
}
}
}

func mergeResponse(existing, new *v3.Response) {
if new.Description != "" {
existing.Description = new.Description
}

// Merge Content
for pair := new.Content.First(); pair != nil; pair = pair.Next() {
contentType := pair.Key()
mediaType := pair.Value()
if _, ok := existing.Content.Get(contentType); !ok {
existing.Content.Set(contentType, mediaType)
}
}

// Merge Headers
if new.Headers != nil {
if existing.Headers == nil {
existing.Headers = orderedmap.New[string, *v3.Header]()
}
for pair := new.Headers.First(); pair != nil; pair = pair.Next() {
if _, ok := existing.Headers.Get(pair.Key()); !ok {
existing.Headers.Set(pair.Key(), pair.Value())
}
}
}

// Merge Links
if new.Links != nil {
if existing.Links == nil {
existing.Links = orderedmap.New[string, *v3.Link]()
}
for pair := new.Links.First(); pair != nil; pair = pair.Next() {
if _, ok := existing.Links.Get(pair.Key()); !ok {
existing.Links.Set(pair.Key(), pair.Value())
}
}
}

// Merge Extensions
for pair := new.Extensions.First(); pair != nil; pair = pair.Next() {
if _, ok := existing.Extensions.Get(pair.Key()); !ok {
existing.Extensions.Set(pair.Key(), pair.Value())
}
}
}

func methodToOperaton(opts options.Options, method protoreflect.MethodDescriptor, returnGet bool) *v3.Operation {
fd := method.ParentFile()
service := method.Parent().(protoreflect.ServiceDescriptor)
Expand Down

0 comments on commit 8442f9e

Please sign in to comment.