diff --git a/converter/converter.go b/converter/converter.go index 7fea0e2..2233998 100644 --- a/converter/converter.go +++ b/converter/converter.go @@ -125,6 +125,7 @@ func WithBaseOpenAPI(baseOpenAPI []byte) Option { } } + // WithAllowGET sets a file to use as a base for all OpenAPI files. func WithAllowGET(allowGet bool) Option { return func(g *generator) error { diff --git a/example.base.yaml b/example.base.yaml index 634eb95..4148f72 100644 --- a/example.base.yaml +++ b/example.base.yaml @@ -2,4 +2,4 @@ openapi: 3.1.0 info: description: "THIS IS FROM BASE!" title: WHOOOOP - version: v1.0.0 \ No newline at end of file + version: v1.0.0 diff --git a/internal/converter/converter_test.go b/internal/converter/converter_test.go index 279823c..38d497d 100644 --- a/internal/converter/converter_test.go +++ b/internal/converter/converter_test.go @@ -18,7 +18,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/sudorandom/protoc-gen-connect-openapi/internal/converter" + "github.com/sudorandom/protoc-gen-connect-openapi/internal/converter/options" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" + pluginpb "google.golang.org/protobuf/types/pluginpb" "gopkg.in/yaml.v3" ) @@ -188,3 +191,69 @@ func makeOutputPath(protofile, format string) string { dir, file := filepath.Split(strings.TrimSuffix(protofile, filepath.Ext(protofile)) + ".openapi." + format) return filepath.Join(dir, "output", file) } + +func TestConvertWithOptions(t *testing.T) { + t.Run("with base file", func(t *testing.T) { + baseYAML := ` +openapi: 3.1.0 +info: + title: Base API + version: 1.0.0 + x-logo: + url: https://example.com/logo.png +paths: + /example/api/path: + post: + x-code-samples: + - language: shell + label: example-api-path + source: | + curl -X POST https://api.example.com/example/api/path \ + -H "Content-Type: application/json" \ + -d '{"email": "user@example.com"}' +` + opts := options.Options{ + Path: "test.openapi.yaml", + Format: "yaml", + BaseOpenAPI: []byte(baseYAML), + } + + req := &pluginpb.CodeGeneratorRequest{ + ProtoFile: []*descriptorpb.FileDescriptorProto{ + { + Name: proto.String("test.proto"), + Package: proto.String("test"), + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("TestMessage")}, + }, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("ExampleApiService"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("ExampleApiPath"), + InputType: proto.String(".test.TestMessage"), + OutputType: proto.String(".test.TestMessage"), + }, + }, + }, + }, + }, + }, + FileToGenerate: []string{"test.proto"}, + } + + resp, err := converter.ConvertWithOptions(req, opts) + require.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.File, 1) + + content := resp.File[0].GetContent() + assert.Contains(t, content, "x-logo:") + assert.Contains(t, content, "url: https://example.com/logo.png") + assert.Contains(t, content, "x-code-samples:") + + // Check that the generated content is merged with the base file + assert.Contains(t, content, "TestMessage") + }) +} diff --git a/internal/converter/paths.go b/internal/converter/paths.go index 4fb8751..566c96c 100644 --- a/internal/converter/paths.go +++ b/internal/converter/paths.go @@ -21,32 +21,26 @@ func addPathItemsFromFile(opts options.Options, fd protoreflect.FileDescriptor, for j := 0; j < methods.Len(); j++ { method := methods.Get(j) pathItems := googleapi.MakePathItems(opts, method) - for pair := pathItems.First(); pair != nil; pair = pair.Next() { - path, item := pair.Key(), pair.Value() - if existing, ok := paths.PathItems.Get(pair.Key()); !ok { - paths.PathItems.Set(path, item) + + // Helper function to update or set path items + addPathItem := func(path string, newItem *v3.PathItem) { + if existing, ok := paths.PathItems.Get(path); !ok { + paths.PathItems.Set(path, newItem) } else { - if item.Get != nil { - existing.Get = item.Get - } - if item.Post != nil { - existing.Post = item.Post - } - if item.Delete != nil { - existing.Delete = item.Delete - } - if item.Put != nil { - existing.Put = item.Put - } - if item.Patch != nil { - existing.Patch = item.Patch - } + mergePathItems(existing, newItem) paths.PathItems.Set(path, existing) } } - // No google.api annotations for this method, so default to the ConnectRPC/gRPC path + + // Update path items from google.api annotations + for pair := pathItems.First(); pair != nil; pair = pair.Next() { + addPathItem(pair.Key(), pair.Value()) + } + + // Default to ConnectRPC/gRPC path if no google.api annotations if pathItems == nil || pathItems.Len() == 0 { - paths.PathItems.Set("/"+string(service.FullName())+"/"+string(method.Name()), methodToPathItem(opts, method)) + path := "/" + string(service.FullName()) + "/" + string(method.Name()) + addPathItem(path, methodToPathItem(opts, method)) } } } @@ -54,6 +48,173 @@ func addPathItemsFromFile(opts options.Options, fd protoreflect.FileDescriptor, return nil } +func mergePathItems(existing, new *v3.PathItem) { + // Merge operations + operations := []struct { + existingOp **v3.Operation + newOp *v3.Operation + }{ + {&existing.Get, new.Get}, + {&existing.Post, new.Post}, + {&existing.Put, new.Put}, + {&existing.Delete, new.Delete}, + {&existing.Options, new.Options}, + {&existing.Head, new.Head}, + {&existing.Patch, new.Patch}, + {&existing.Trace, new.Trace}, + } + + for _, op := range operations { + if op.newOp != nil { + mergeOperation(op.existingOp, op.newOp) + } + } + + // Merge other fields + if new.Summary != "" { + existing.Summary = new.Summary + } + if new.Description != "" { + existing.Description = new.Description + } + existing.Servers = append(existing.Servers, new.Servers...) + existing.Parameters = append(existing.Parameters, new.Parameters...) + + // Merge extensions + for pair := new.Extensions.First(); pair != nil; pair = pair.Next() { + if _, ok := existing.Extensions.Get(pair.Key()); !ok { + existing.Extensions.Set(pair.Key(), pair.Value()) + } + } +} + +func mergeOperation(existing **v3.Operation, new *v3.Operation) { + if *existing == nil { + *existing = new + return + } + // Merge operation fields + if new.Summary != "" { + (*existing).Summary = new.Summary + } + if new.Description != "" { + (*existing).Description = new.Description + } + (*existing).Tags = append((*existing).Tags, new.Tags...) + (*existing).Parameters = append((*existing).Parameters, new.Parameters...) + if new.RequestBody != nil { + (*existing).RequestBody = new.RequestBody + } + if new.Responses != nil { + mergeResponses((*existing).Responses, new.Responses) + } + if new.Deprecated != nil { + (*existing).Deprecated = new.Deprecated + } + + // Add support for additional Operation fields + if new.Callbacks != nil { + if (*existing).Callbacks == nil { + (*existing).Callbacks = orderedmap.New[string, *v3.Callback]() + } + for pair := new.Callbacks.First(); pair != nil; pair = pair.Next() { + if _, ok := (*existing).Callbacks.Get(pair.Key()); !ok { + (*existing).Callbacks.Set(pair.Key(), pair.Value()) + } + } + } + + if new.Security != nil { + (*existing).Security = append((*existing).Security, new.Security...) + } + + if new.Servers != nil { + (*existing).Servers = append((*existing).Servers, new.Servers...) + } + + if new.ExternalDocs != nil { + (*existing).ExternalDocs = new.ExternalDocs + } + + // Merge extensions + for pair := new.Extensions.First(); pair != nil; pair = pair.Next() { + if _, ok := (*existing).Extensions.Get(pair.Key()); !ok { + (*existing).Extensions.Set(pair.Key(), pair.Value()) + } + } +} + +func mergeResponses(existing, new *v3.Responses) { + if existing == nil || new == nil { + return + } + + // Merge response codes + for pair := new.Codes.First(); pair != nil; pair = pair.Next() { + code := pair.Key() + if existingResponse, ok := existing.Codes.Get(code); !ok { + existing.Codes.Set(code, pair.Value()) + } else { + mergeResponse(existingResponse, pair.Value()) + } + } + + // Merge default response + if new.Default != nil { + if existing.Default == nil { + existing.Default = new.Default + } else { + mergeResponse(existing.Default, new.Default) + } + } +} + +func mergeResponse(existing, new *v3.Response) { + if new.Description != "" { + existing.Description = new.Description + } + + // Merge Content + for pair := new.Content.First(); pair != nil; pair = pair.Next() { + contentType := pair.Key() + mediaType := pair.Value() + if _, ok := existing.Content.Get(contentType); !ok { + existing.Content.Set(contentType, mediaType) + } + } + + // Merge Headers + if new.Headers != nil { + if existing.Headers == nil { + existing.Headers = orderedmap.New[string, *v3.Header]() + } + for pair := new.Headers.First(); pair != nil; pair = pair.Next() { + if _, ok := existing.Headers.Get(pair.Key()); !ok { + existing.Headers.Set(pair.Key(), pair.Value()) + } + } + } + + // Merge Links + if new.Links != nil { + if existing.Links == nil { + existing.Links = orderedmap.New[string, *v3.Link]() + } + for pair := new.Links.First(); pair != nil; pair = pair.Next() { + if _, ok := existing.Links.Get(pair.Key()); !ok { + existing.Links.Set(pair.Key(), pair.Value()) + } + } + } + + // Merge Extensions + for pair := new.Extensions.First(); pair != nil; pair = pair.Next() { + if _, ok := existing.Extensions.Get(pair.Key()); !ok { + existing.Extensions.Set(pair.Key(), pair.Value()) + } + } +} + func methodToOperaton(opts options.Options, method protoreflect.MethodDescriptor, returnGet bool) *v3.Operation { fd := method.ParentFile() service := method.Parent().(protoreflect.ServiceDescriptor)