From b07512eba50b261d583ac4e96baff2af788583b8 Mon Sep 17 00:00:00 2001 From: Richard Carson Derr Date: Thu, 22 Aug 2024 22:46:16 -0400 Subject: [PATCH 01/12] refactor(issue-248): rename endpoint type and change openapi method signature --- rest/endpoint/endpoint.go | 273 ++++++++--------------- rest/endpoint/endpoint_test.go | 252 ++++----------------- rest/endpoint/openapi.go | 156 ------------- rest/endpoint/openapi_test.go | 388 ++++----------------------------- 4 files changed, 179 insertions(+), 890 deletions(-) delete mode 100644 rest/endpoint/openapi.go diff --git a/rest/endpoint/endpoint.go b/rest/endpoint/endpoint.go index acf3dca..dfe077b 100644 --- a/rest/endpoint/endpoint.go +++ b/rest/endpoint/endpoint.go @@ -12,9 +12,7 @@ import ( "fmt" "io" "net/http" - "reflect" "strconv" - "strings" "github.com/z5labs/bedrock/pkg/ptr" @@ -47,26 +45,22 @@ type options struct { method string pattern string + pathParams map[PathParam]struct{} + headerParams map[Header]struct{} + queryParams map[QueryParam]struct{} + defaultStatusCode int validators []func(*http.Request) error errHandler ErrorHandler - schemas map[string]*openapi3.Schema - pathParams []*openapi3.Parameter - headers []*openapi3.Parameter - queryParams []*openapi3.Parameter - request *openapi3.RequestBody - responses *openapi3.Responses + openapi *openapi3.Operation } // Option type Option func(*options) -// Endpoint -type Endpoint[Req, Resp any] struct { - method string - pattern string - +// Operation +type Operation[Req, Resp any] struct { validators []func(*http.Request) error injectors []injector @@ -75,7 +69,7 @@ type Endpoint[Req, Resp any] struct { errHandler ErrorHandler - openapi func(*openapi3.Spec) + openapi *openapi3.Operation } const DefaultStatusCode = http.StatusOK @@ -87,48 +81,29 @@ func StatusCode(statusCode int) Option { } } -type pathParam struct { - name string -} - -func parsePathParams(s string) []pathParam { - var params []pathParam - var found bool - for { - if len(s) == 0 { - return params - } - - _, s, found = strings.Cut(s, "{") - if !found { - return params - } - - i := strings.IndexByte(s, '}') - if i == -1 { - return params - } - - param := s[:i] - s = s[i:] - - name := strings.TrimSuffix(param, ".") - params = append(params, pathParam{ - name: name, - }) - } +// PathParam +type PathParam struct { + Name string + Pattern string + Required bool } -func pathParams(ps ...pathParam) Option { +// PathParams +func PathParams(ps ...PathParam) Option { return func(o *options) { for _, p := range ps { - o.pathParams = append(o.pathParams, &openapi3.Parameter{ - In: openapi3.ParameterInPath, - Name: p.name, - Required: ptr.Ref(true), - Schema: &openapi3.SchemaOrRef{ - Schema: &openapi3.Schema{ - Type: ptr.Ref(openapi3.SchemaTypeString), + o.pathParams[p] = struct{}{} + + o.openapi.Parameters = append(o.openapi.Parameters, openapi3.ParameterOrRef{ + Parameter: &openapi3.Parameter{ + In: openapi3.ParameterInPath, + Name: p.Name, + Required: ptr.Ref(p.Required), + Schema: &openapi3.SchemaOrRef{ + Schema: &openapi3.Schema{ + Type: ptr.Ref(openapi3.SchemaTypeString), + Pattern: ptr.Ref(p.Pattern), + }, }, }, }) @@ -147,13 +122,18 @@ type Header struct { func Headers(hs ...Header) Option { return func(o *options) { for _, h := range hs { - o.headers = append(o.headers, &openapi3.Parameter{ - In: openapi3.ParameterInHeader, - Name: h.Name, - Required: ptr.Ref(h.Required), - Schema: &openapi3.SchemaOrRef{ - Schema: &openapi3.Schema{ - Type: ptr.Ref(openapi3.SchemaTypeString), + o.headerParams[h] = struct{}{} + + o.openapi.Parameters = append(o.openapi.Parameters, openapi3.ParameterOrRef{ + Parameter: &openapi3.Parameter{ + In: openapi3.ParameterInHeader, + Name: h.Name, + Required: ptr.Ref(h.Required), + Schema: &openapi3.SchemaOrRef{ + Schema: &openapi3.Schema{ + Type: ptr.Ref(openapi3.SchemaTypeString), + Pattern: ptr.Ref(h.Pattern), + }, }, }, }) @@ -174,13 +154,18 @@ type QueryParam struct { func QueryParams(qps ...QueryParam) Option { return func(o *options) { for _, qp := range qps { - o.queryParams = append(o.queryParams, &openapi3.Parameter{ - In: openapi3.ParameterInQuery, - Name: qp.Name, - Required: ptr.Ref(qp.Required), - Schema: &openapi3.SchemaOrRef{ - Schema: &openapi3.Schema{ - Type: ptr.Ref(openapi3.SchemaTypeString), + o.queryParams[qp] = struct{}{} + + o.openapi.Parameters = append(o.openapi.Parameters, openapi3.ParameterOrRef{ + Parameter: &openapi3.Parameter{ + In: openapi3.ParameterInQuery, + Name: qp.Name, + Required: ptr.Ref(qp.Required), + Schema: &openapi3.SchemaOrRef{ + Schema: &openapi3.Schema{ + Type: ptr.Ref(openapi3.SchemaTypeString), + Pattern: ptr.Ref(qp.Pattern), + }, }, }, }) @@ -198,10 +183,6 @@ type ContentTyper interface { // Accepts func Accepts[Req any]() Option { return func(o *options) { - if o.request == nil { - o.request = new(openapi3.RequestBody) - } - contentType := "" var req Req @@ -224,18 +205,12 @@ func Accepts[Req any]() Option { var schemaOrRef openapi3.SchemaOrRef schemaOrRef.FromJSONSchema(schema.ToSchemaOrBool()) - typeName := reflect.TypeOf(req).Name() - schemaRef := fmt.Sprintf("#/components/schemas/%s", typeName) - o.schemas[typeName] = schemaOrRef.Schema - - o.request = &openapi3.RequestBody{ - Required: ptr.Ref(true), - Content: map[string]openapi3.MediaType{ - contentType: { - Schema: &openapi3.SchemaOrRef{ - SchemaReference: &openapi3.SchemaReference{ - Ref: schemaRef, - }, + o.openapi.RequestBody = &openapi3.RequestBodyOrRef{ + RequestBody: &openapi3.RequestBody{ + Required: ptr.Ref(true), + Content: map[string]openapi3.MediaType{ + contentType: { + Schema: &schemaOrRef, }, }, }, @@ -246,13 +221,7 @@ func Accepts[Req any]() Option { // Returns func Returns(status int) Option { return func(o *options) { - if o.responses == nil { - o.responses = &openapi3.Responses{ - MapOfResponseOrRefValues: make(map[string]openapi3.ResponseOrRef), - } - } - - o.responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ + o.openapi.Responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ Response: &openapi3.Response{}, } } @@ -261,19 +230,10 @@ func Returns(status int) Option { // ReturnsWith func ReturnsWith[Resp any](status int) Option { return func(o *options) { - if o.responses == nil { - o.responses = &openapi3.Responses{ - MapOfResponseOrRefValues: make(map[string]openapi3.ResponseOrRef), - } - } - - contentType := "" - var resp Resp - if ct, ok := any(resp).(ContentTyper); ok { - contentType = ct.ContentType() - } else { - o.responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ + ct, ok := any(resp).(ContentTyper) + if !ok { + o.openapi.Responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ Response: &openapi3.Response{}, } return @@ -288,19 +248,11 @@ func ReturnsWith[Resp any](status int) Option { var schemaOrRef openapi3.SchemaOrRef schemaOrRef.FromJSONSchema(schema.ToSchemaOrBool()) - typeName := reflect.TypeOf(resp).Name() - schemaRef := fmt.Sprintf("#/components/schemas/%s", typeName) - o.schemas[typeName] = schemaOrRef.Schema - - o.responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ + o.openapi.Responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ Response: &openapi3.Response{ Content: map[string]openapi3.MediaType{ - contentType: { - Schema: &openapi3.SchemaOrRef{ - SchemaReference: &openapi3.SchemaReference{ - Ref: schemaRef, - }, - }, + ct.ContentType(): { + Schema: &schemaOrRef, }, }, }, @@ -325,40 +277,37 @@ func (f errorHandlerFunc) HandleError(w http.ResponseWriter, err error) { const DefaultErrorStatusCode = http.StatusInternalServerError // New initializes an Endpoint. -func New[Req, Resp any](method string, pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { +func New[Req, Resp any](handler Handler[Req, Resp], opts ...Option) *Operation[Req, Resp] { o := &options{ - method: method, - pattern: pattern, defaultStatusCode: DefaultStatusCode, - validators: []func(*http.Request) error{ - validateMethod(method), - }, + pathParams: make(map[PathParam]struct{}), + headerParams: make(map[Header]struct{}), + queryParams: make(map[QueryParam]struct{}), errHandler: errorHandlerFunc(func(w http.ResponseWriter, err error) { w.WriteHeader(DefaultErrorStatusCode) }), - schemas: make(map[string]*openapi3.Schema), + openapi: &openapi3.Operation{ + Responses: openapi3.Responses{ + MapOfResponseOrRefValues: make(map[string]openapi3.ResponseOrRef), + }, + }, } - for _, opt := range withBuiltinOptions[Req, Resp](pattern, opts...) { + for _, opt := range withBuiltinOptions[Req, Resp](opts...) { opt(o) } - return &Endpoint[Req, Resp]{ - method: method, - pattern: pattern, + return &Operation[Req, Resp]{ injectors: initInjectors(o), validators: o.validators, statusCode: o.defaultStatusCode, handler: handler, errHandler: o.errHandler, - openapi: setOpenApiSpec(o), + openapi: o.openapi, } } -func withBuiltinOptions[Req, Resp any](pattern string, opts ...Option) []Option { - parsedPathParams := parsePathParams(pattern) - opts = append(opts, pathParams(parsedPathParams...)) - +func withBuiltinOptions[Req, Resp any](opts ...Option) []Option { var req Req if _, ok := any(req).(ContentTyper); ok { opts = append(opts, Accepts[Req]()) @@ -380,10 +329,10 @@ func withBuiltinOptions[Req, Resp any](pattern string, opts ...Option) []Option func initInjectors(o *options) []injector { injectors := []injector{injectResponseHeaders} - for _, p := range o.pathParams { + for p := range o.pathParams { injectors = append(injectors, injectPathParam(p.Name)) } - if len(o.headers) > 0 { + if len(o.headerParams) > 0 { injectors = append(injectors, injectHeaders) } if len(o.queryParams) > 0 { @@ -392,80 +341,48 @@ func initInjectors(o *options) []injector { return injectors } -// Get returns an Endpoint configured for handling HTTP GET requests. -func Get[Req, Resp any](pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { - return New(http.MethodGet, pattern, handler, opts...) -} - -// Post returns an Endpoint configured for handling HTTP POST requests. -func Post[Req, Resp any](pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { - return New(http.MethodPost, pattern, handler, opts...) -} - -// Put returns an Endpoint configured for handling HTTP PUT requests. -func Put[Req, Resp any](pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { - return New(http.MethodPut, pattern, handler, opts...) -} - -// Delete returns an Endpoint configured for handling HTTP DELETE requests. -func Delete[Req, Resp any](pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { - return New(http.MethodDelete, pattern, handler, opts...) -} - -// Method returns the HTTP method which this endpoint -// is configured to handle requests for. -func (e *Endpoint[Req, Resp]) Method() string { - return e.method -} - -// Pattern returns HTTP path pattern for this endpoint. -func (e *Endpoint[Req, Resp]) Pattern() string { - return e.pattern -} - -// OpenApi allows the endpoint to register itself with an OpenAPI spec. -func (e *Endpoint[Req, Resp]) OpenApi(spec *openapi3.Spec) { - e.openapi(spec) +func (op *Operation[Req, Resp]) OpenApi() *openapi3.Operation { + return op.openapi } // ServeHTTP implements the [http.Handler] interface. -func (e *Endpoint[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := inject(r.Context(), w, r, e.injectors...) +func (op *Operation[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := inject(r.Context(), w, r, op.injectors...) - err := validateRequest(r, e.validators...) + err := validateRequest(r, op.validators...) if err != nil { - e.handleError(w, r, err) + op.handleError(w, r, err) return } var req Req err = unmarshal(r.Body, &req) if err != nil { - e.handleError(w, r, err) + op.handleError(w, r, err) return } err = validate(req) if err != nil { - e.handleError(w, r, err) + op.handleError(w, r, err) return } - resp, err := e.handler.Handle(ctx, req) + resp, err := op.handler.Handle(ctx, req) if err != nil { - e.handleError(w, r, err) + op.handleError(w, r, err) return } bm, ok := any(resp).(encoding.BinaryMarshaler) if !ok { - w.WriteHeader(e.statusCode) + w.WriteHeader(op.statusCode) return } b, err := bm.MarshalBinary() if err != nil { - e.handleError(w, r, err) + op.handleError(w, r, err) return } @@ -473,21 +390,21 @@ func (e *Endpoint[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request) w.Header().Set("Content-Type", ct.ContentType()) } - w.WriteHeader(e.statusCode) + w.WriteHeader(op.statusCode) _, err = io.Copy(w, bytes.NewReader(b)) if err != nil { - e.handleError(w, r, err) + op.handleError(w, r, err) return } } -func (e *Endpoint[Req, Resp]) handleError(w http.ResponseWriter, r *http.Request, err error) { +func (op *Operation[Req, Resp]) handleError(w http.ResponseWriter, r *http.Request, err error) { if h, ok := err.(http.Handler); ok { h.ServeHTTP(w, r) return } - e.errHandler.HandleError(w, err) + op.errHandler.HandleError(w, err) } func unmarshal[Req any](r io.ReadCloser, req *Req) error { diff --git a/rest/endpoint/endpoint_test.go b/rest/endpoint/endpoint_test.go index 3a63c96..073687f 100644 --- a/rest/endpoint/endpoint_test.go +++ b/rest/endpoint/endpoint_test.go @@ -80,110 +80,13 @@ func (FailMarshalBinary) MarshalBinary() ([]byte, error) { return nil, errMarshalBinary } -func TestGet(t *testing.T) { - t.Run("will return 405 http status code", func(t *testing.T) { - t.Run("if a non GET method is used to call its returned endpoint", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - }) -} - -func TestPost(t *testing.T) { - t.Run("will return 405 http status code", func(t *testing.T) { - t.Run("if a non POST method is used to call its returned endpoint", func(t *testing.T) { - pattern := "/" - - e := Post( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - }) -} - -func TestPut(t *testing.T) { - t.Run("will return 405 http status code", func(t *testing.T) { - t.Run("if a non PUT method is used to call its returned endpoint", func(t *testing.T) { - pattern := "/" - - e := Put( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - }) -} - -func TestDelete(t *testing.T) { - t.Run("will return 405 http status code", func(t *testing.T) { - t.Run("if a non DELETE method is used to call its returned endpoint", func(t *testing.T) { - pattern := "/" - - e := Delete( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - }) -} - func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return the default success http status code", func(t *testing.T) { t.Run("if the underlying Handler succeeds with an empty response", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, - noopHandler{}, - ) + e := New(noopHandler{}) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -194,17 +97,14 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the underlying Handler succeeds with a encoding.BinaryMarshaler response", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{Value: "hello, world"}, nil }), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -231,14 +131,14 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will inject path params", func(t *testing.T) { t.Run("if a valid http.ServeMux path param pattern is used", func(t *testing.T) { - pattern := "/{id}" - - e := Get( - pattern, + e := New( HandlerFunc[Empty, JsonContent](func(ctx context.Context, _ Empty) (JsonContent, error) { v := PathValue(ctx, "id") return JsonContent{Value: v}, nil }), + PathParams(PathParam{ + Name: "id", + }), ) w := httptest.NewRecorder() @@ -247,7 +147,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { // for path params a http.ServeMux must be used since // Endpoint doesn't support it directly mux := http.NewServeMux() - mux.Handle(pattern, e) + mux.Handle("GET /{id}", e) mux.ServeHTTP(w, r) resp := w.Result() @@ -273,10 +173,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will inject headers", func(t *testing.T) { t.Run("if a header is configured with the Headers option", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( HandlerFunc[Empty, JsonContent](func(ctx context.Context, _ Empty) (JsonContent, error) { v := HeaderValue(ctx, "test-header") return JsonContent{Value: v}, nil @@ -287,7 +184,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("test-header", "hello, world") e.ServeHTTP(w, r) @@ -315,10 +212,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will inject query params", func(t *testing.T) { t.Run("if a query param is configured with the QueryParams option", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( HandlerFunc[Empty, JsonContent](func(ctx context.Context, _ Empty) (JsonContent, error) { v := QueryValue(ctx, "test-query") return JsonContent{Value: v}, nil @@ -329,7 +223,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern+"?test-query=abc123", nil) + r := httptest.NewRequest(http.MethodGet, "/?test-query=abc123", nil) e.ServeHTTP(w, r) @@ -356,20 +250,18 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return custom success http status code", func(t *testing.T) { t.Run("if the StatusCode option is used and the underlying Handler succeeds with an empty response", func(t *testing.T) { - pattern := "/" statusCode := http.StatusCreated if !assert.NotEqual(t, DefaultStatusCode, statusCode) { return } - e := Get( - pattern, + e := New( noopHandler{}, StatusCode(statusCode), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -380,14 +272,12 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the StatusCode option is used and the underlying Handler succeeds with a encoding.BinaryMarshaler response", func(t *testing.T) { - pattern := "/" statusCode := http.StatusCreated if !assert.NotEqual(t, DefaultStatusCode, statusCode) { return } - e := Get( - pattern, + e := New( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{Value: "hello, world"}, nil }), @@ -395,7 +285,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -422,17 +312,14 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return non-success http status code", func(t *testing.T) { t.Run("if the underlying Handler returns an error", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, errors.New("failed") }), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -443,14 +330,12 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if a custom error handler is set", func(t *testing.T) { - pattern := "/" errStatusCode := http.StatusServiceUnavailable if !assert.NotEqual(t, DefaultErrorStatusCode, errStatusCode) { return } - e := Get( - pattern, + e := New( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, errors.New("failed") }), @@ -460,7 +345,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -471,21 +356,19 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the underlying error implements http.Handler", func(t *testing.T) { - pattern := "/" errStatusCode := http.StatusServiceUnavailable if !assert.NotEqual(t, DefaultErrorStatusCode, errStatusCode) { return } - e := Get( - pattern, + e := New( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, httpError{status: errStatusCode} }), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -495,30 +378,8 @@ func TestEndpoint_ServeHTTP(t *testing.T) { } }) - t.Run("if the http request is for the wrong http method", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - t.Run("if a required http header is missing", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( noopHandler{}, Headers( Header{ @@ -529,7 +390,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -540,10 +401,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if a http header does not match its expected pattern", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( noopHandler{}, Headers( Header{ @@ -554,7 +412,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("Authorization", "abc123") e.ServeHTTP(w, r) @@ -566,10 +424,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if a required query param is missing", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( noopHandler{}, QueryParams( QueryParam{ @@ -580,7 +435,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -591,10 +446,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if a query param does not match its expected pattern", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( noopHandler{}, QueryParams( QueryParam{ @@ -605,7 +457,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern+"?test=abc123", nil) + r := httptest.NewRequest(http.MethodGet, "/?test=abc123", nil) e.ServeHTTP(w, r) @@ -616,17 +468,14 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the request content type header does not match the content type from ContentTyper", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( HandlerFunc[JsonContent, Empty](func(_ context.Context, _ JsonContent) (Empty, error) { return Empty{}, nil }), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Content-Type", "application/xml") e.ServeHTTP(w, r) @@ -638,11 +487,8 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the request body fails to unmarshal", func(t *testing.T) { - pattern := "/" - var caughtError error - e := Post( - pattern, + e := New( HandlerFunc[FailUnmarshalBinary, Empty](func(_ context.Context, _ FailUnmarshalBinary) (Empty, error) { return Empty{}, nil }), @@ -654,7 +500,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, strings.NewReader(`{}`)) + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) e.ServeHTTP(w, r) @@ -668,11 +514,8 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the unmarshaled request body is invalid", func(t *testing.T) { - pattern := "/" - var caughtError error - e := Post( - pattern, + e := New( HandlerFunc[InvalidRequest, Empty](func(_ context.Context, _ InvalidRequest) (Empty, error) { return Empty{}, nil }), @@ -684,7 +527,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, strings.NewReader(`{}`)) + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) e.ServeHTTP(w, r) @@ -698,11 +541,8 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the response body fails to marshal itself to binary", func(t *testing.T) { - pattern := "/" - var caughtError error - e := Post( - pattern, + e := New( HandlerFunc[Empty, FailMarshalBinary](func(_ context.Context, _ Empty) (FailMarshalBinary, error) { return FailMarshalBinary{}, nil }), @@ -714,7 +554,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, strings.NewReader(`{}`)) + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) e.ServeHTTP(w, r) @@ -730,17 +570,14 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return response header", func(t *testing.T) { t.Run("if the response body implements ContentTyper", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{Value: "hello, world"}, nil }), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -768,10 +605,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the underlying Handler sets a custom response header using the context", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := New( HandlerFunc[Empty, Empty](func(ctx context.Context, _ Empty) (Empty, error) { SetResponseHeader(ctx, "Content-Type", "test-content-type") return Empty{}, nil @@ -779,7 +613,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) diff --git a/rest/endpoint/openapi.go b/rest/endpoint/openapi.go deleted file mode 100644 index 299cecf..0000000 --- a/rest/endpoint/openapi.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2024 Z5Labs and Contributors -// -// This software is released under the MIT License. -// https://opensource.org/licenses/MIT - -package endpoint - -import ( - "github.com/swaggest/openapi-go/openapi3" -) - -func setOpenApiSpec(o *options) func(*openapi3.Spec) { - return compose( - addSchemas(o.schemas), - addParameters(o.method, o.pattern, o.pathParams...), - addParameters(o.method, o.pattern, o.headers...), - addParameters(o.method, o.pattern, o.queryParams...), - addRequestBody(o.method, o.pattern, o.request), - addResponses(o.method, o.pattern, o.responses), - ) -} - -func compose(fs ...func(*openapi3.Spec)) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - for _, f := range fs { - f(s) - } - } -} - -func addSchemas(schemas map[string]*openapi3.Schema) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - if len(schemas) == 0 { - return - } - - if s.Components == nil { - s.Components = new(openapi3.Components) - } - - comps := s.Components - if comps.Schemas == nil { - comps.Schemas = new(openapi3.ComponentsSchemas) - } - - compSchemas := comps.Schemas - if compSchemas.MapOfSchemaOrRefValues == nil { - compSchemas.MapOfSchemaOrRefValues = make(map[string]openapi3.SchemaOrRef, len(schemas)) - } - - for name, schema := range schemas { - compSchemas.MapOfSchemaOrRefValues[name] = openapi3.SchemaOrRef{ - Schema: schema, - } - } - } -} - -func addParameters(method, pattern string, params ...*openapi3.Parameter) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - if len(params) == 0 { - return - } - - if s.Paths.MapOfPathItemValues == nil { - s.Paths.MapOfPathItemValues = make(map[string]openapi3.PathItem) - } - - pathItemVals := s.Paths.MapOfPathItemValues - pathItem, ok := pathItemVals[pattern] - if !ok { - pathItem = openapi3.PathItem{ - MapOfOperationValues: make(map[string]openapi3.Operation), - } - } - - opVals := pathItem.MapOfOperationValues - opVal, ok := opVals[method] - if !ok { - opVal = openapi3.Operation{} - } - - for _, param := range params { - opVal.Parameters = append(opVal.Parameters, openapi3.ParameterOrRef{ - Parameter: param, - }) - } - - opVals[method] = opVal - pathItemVals[pattern] = pathItem - } -} - -func addRequestBody(method, pattern string, reqBody *openapi3.RequestBody) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - if reqBody == nil { - return - } - - if s.Paths.MapOfPathItemValues == nil { - s.Paths.MapOfPathItemValues = make(map[string]openapi3.PathItem) - } - - pathItemVals := s.Paths.MapOfPathItemValues - pathItem, ok := pathItemVals[pattern] - if !ok { - pathItem = openapi3.PathItem{ - MapOfOperationValues: make(map[string]openapi3.Operation), - } - } - - opVals := pathItem.MapOfOperationValues - opVal, ok := opVals[method] - if !ok { - opVal = openapi3.Operation{} - } - - opVal.RequestBody = &openapi3.RequestBodyOrRef{ - RequestBody: reqBody, - } - - opVals[method] = opVal - pathItemVals[pattern] = pathItem - } -} - -func addResponses(method, pattern string, responses *openapi3.Responses) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - if responses == nil { - return - } - - if s.Paths.MapOfPathItemValues == nil { - s.Paths.MapOfPathItemValues = make(map[string]openapi3.PathItem) - } - - pathItemVals := s.Paths.MapOfPathItemValues - pathItem, ok := pathItemVals[pattern] - if !ok { - pathItem = openapi3.PathItem{ - MapOfOperationValues: make(map[string]openapi3.Operation), - } - } - - opVals := pathItem.MapOfOperationValues - opVal, ok := opVals[method] - if !ok { - opVal = openapi3.Operation{} - } - - opVal.Responses = *responses - - opVals[method] = opVal - pathItemVals[pattern] = pathItem - } -} diff --git a/rest/endpoint/openapi_test.go b/rest/endpoint/openapi_test.go index cb2b7da..3cbd45f 100644 --- a/rest/endpoint/openapi_test.go +++ b/rest/endpoint/openapi_test.go @@ -9,9 +9,7 @@ import ( "context" "encoding/json" "net/http" - "path" "strconv" - "strings" "testing" "github.com/z5labs/bedrock/pkg/ptr" @@ -23,50 +21,27 @@ import ( func TestEndpoint_OpenApi(t *testing.T) { t.Run("will required path parameter", func(t *testing.T) { t.Run("if a http.ServeMux path parameter pattern is used", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/{id}" - e := New( - method, - pattern, HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), + PathParams(PathParam{ + Name: "id", + Required: true, + }), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -90,55 +65,29 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set non-required header parameter", func(t *testing.T) { t.Run("if a header is provided with the Headers option", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" header := Header{ Name: "MyHeader", Required: true, } e := New( - method, - pattern, HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), Headers(header), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -162,55 +111,29 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set required header parameter", func(t *testing.T) { t.Run("if a header is provided with the Headers option", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" header := Header{ Name: "MyHeader", Required: true, } e := New( - method, - pattern, HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), Headers(header), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -234,54 +157,28 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set non-required query param", func(t *testing.T) { t.Run("if a query param is provided with the QueryParams option", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" queryParam := QueryParam{ Name: "myparam", } e := New( - method, - pattern, HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), QueryParams(queryParam), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -305,55 +202,29 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set required query param", func(t *testing.T) { t.Run("if a query param is provided with the QueryParams option", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" queryParam := QueryParam{ Name: "myparam", Required: true, } e := New( - method, - pattern, HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), QueryParams(queryParam), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -377,50 +248,24 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set request body type", func(t *testing.T) { t.Run("if the request type implements ContentTyper interface", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" - e := New( - method, - pattern, HandlerFunc[JsonContent, Empty](func(_ context.Context, _ JsonContent) (Empty, error) { return Empty{}, nil }), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - reqBodyOrRef := ops[method].RequestBody + reqBodyOrRef := op.RequestBody if !assert.NotNil(t, reqBodyOrRef) { return } @@ -443,31 +288,7 @@ func TestEndpoint_OpenApi(t *testing.T) { return } - schemaRef := schemaOrRef.SchemaReference - if !assert.NotNil(t, schemaRef) { - return - } - _, schemaRefName := path.Split(schemaRef.Ref) - - comps := spec.Components - if !assert.NotNil(t, comps) { - return - } - - schemas := comps.Schemas - if !assert.NotNil(t, schemas) { - return - } - - schemaOrRefValues := schemas.MapOfSchemaOrRefValues - if !assert.Len(t, schemaOrRefValues, 1) { - return - } - if !assert.Contains(t, schemaOrRefValues, schemaRefName) { - return - } - - schema := schemaOrRefValues[schemaRefName].Schema + schema := schemaOrRef.Schema if !assert.NotNil(t, schema) { return } @@ -484,50 +305,24 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set response body type", func(t *testing.T) { t.Run("if the response type implements ContentTyper interface", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" - e := New( - method, - pattern, HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{}, nil }), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - respOrRefValues := ops[method].Responses.MapOfResponseOrRefValues + respOrRefValues := op.Responses.MapOfResponseOrRefValues if !assert.Len(t, respOrRefValues, 1) { return } @@ -553,31 +348,7 @@ func TestEndpoint_OpenApi(t *testing.T) { return } - schemaRef := schemaOrRef.SchemaReference - if !assert.NotNil(t, schemaRef) { - return - } - _, respRefName := path.Split(schemaRef.Ref) - - comps := spec.Components - if !assert.NotNil(t, comps) { - return - } - - schemas := comps.Schemas - if !assert.NotNil(t, schemas) { - return - } - - schemaOrRefValues := schemas.MapOfSchemaOrRefValues - if !assert.Len(t, schemaOrRefValues, 1) { - return - } - if !assert.Contains(t, schemaOrRefValues, respRefName) { - return - } - - schema := schemaOrRefValues[respRefName].Schema + schema := schemaOrRef.Schema if !assert.NotNil(t, schema) { return } @@ -594,50 +365,24 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set a empty response body", func(t *testing.T) { t.Run("if the response type does not implement ContentTyper", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" - e := New( - method, - pattern, HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - respOrRefValues := ops[method].Responses.MapOfResponseOrRefValues + respOrRefValues := op.Responses.MapOfResponseOrRefValues if !assert.Len(t, respOrRefValues, 1) { return } @@ -657,52 +402,27 @@ func TestEndpoint_OpenApi(t *testing.T) { }) t.Run("if the Returns option is used with a http status code", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" statusCode := http.StatusBadRequest e := New( - method, - pattern, HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), Returns(statusCode), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - respOrRefValues := ops[method].Responses.MapOfResponseOrRefValues + respOrRefValues := op.Responses.MapOfResponseOrRefValues if !assert.Len(t, respOrRefValues, 2) { return } @@ -733,56 +453,30 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will override default response status code", func(t *testing.T) { t.Run("if DefaultStatusCode option is used", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" - statusCode := http.StatusCreated if !assert.NotEqual(t, statusCode, DefaultStatusCode) { return } e := New( - method, - pattern, HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), StatusCode(statusCode), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - respOrRefValues := ops[method].Responses.MapOfResponseOrRefValues + respOrRefValues := op.Responses.MapOfResponseOrRefValues if !assert.Len(t, respOrRefValues, 1) { return } From c7bcdf2c2f211cf16df22c9acf7c00fc85881ec5 Mon Sep 17 00:00:00 2001 From: Richard Carson Derr Date: Thu, 22 Aug 2024 22:47:25 -0400 Subject: [PATCH 02/12] refactor(issue-248): update rest package and example after endpoint package changes --- example/simple_rest/app/app.go | 7 +++---- rest/rest.go | 14 +++++++------- rest/rest_example_test.go | 5 ++--- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/example/simple_rest/app/app.go b/example/simple_rest/app/app.go index b2d39f8..2e36819 100644 --- a/example/simple_rest/app/app.go +++ b/example/simple_rest/app/app.go @@ -39,10 +39,9 @@ func Init(ctx context.Context, cfg Config) (bedrock.App, error) { restApp := rest.NewApp( rest.ListenOn(cfg.Http.Port), - rest.Handle( - "/echo", - endpoint.Post( - "/echo", + rest.Endpoint( + "POST /echo", + endpoint.New( echoService, endpoint.Headers( endpoint.Header{ diff --git a/rest/rest.go b/rest/rest.go index 7841bf8..6078238 100644 --- a/rest/rest.go +++ b/rest/rest.go @@ -33,23 +33,23 @@ func ListenOn(port uint) Option { } } -// Handler represents anything that can handle HTTP requests +// Operation represents anything that can handle HTTP requests // and provide OpenAPI documentation for itself. -type Handler interface { +type Operation interface { http.Handler - OpenApi(*openapi3.Spec) + OpenApi() *openapi3.Operation } -// Handler registers the provider [Handler] with both +// Endpoint registers the [Operation] with both // the App wide OpenAPI spec and the App wide HTTP server. -func Handle(pattern string, h Handler) Option { +func Endpoint(pattern string, op Operation) Option { return func(app *App) { - h.OpenApi(app.spec) + // h.OpenApi(app.spec) app.mux.Handle( pattern, - otelhttp.WithRouteTag(pattern, h), + otelhttp.WithRouteTag(pattern, op), ) } } diff --git a/rest/rest_example_test.go b/rest/rest_example_test.go index 0916a61..57e8ff3 100644 --- a/rest/rest_example_test.go +++ b/rest/rest_example_test.go @@ -73,10 +73,9 @@ func Example() { app := NewApp( listenOnRandomPort(addrCh), - Handle( + Endpoint( "/", - endpoint.Post( - "/", + endpoint.New( echoService{}, ), ), From da8b46d493d6b5e855a358a329765a6038c14490 Mon Sep 17 00:00:00 2001 From: GitHub Action Date: Fri, 23 Aug 2024 02:48:31 +0000 Subject: [PATCH 03/12] chore(docs): updated coverage badge. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ff0e425..89c8627 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) [![Go Reference](https://pkg.go.dev/badge/github.com/z5labs/bedrock.svg)](https://pkg.go.dev/github.com/z5labs/bedrock) [![Go Report Card](https://goreportcard.com/badge/github.com/z5labs/bedrock)](https://goreportcard.com/report/github.com/z5labs/bedrock) -![Coverage](https://img.shields.io/badge/Coverage-95.0%25-brightgreen) +![Coverage](https://img.shields.io/badge/Coverage-94.1%25-brightgreen) [![build](https://github.com/z5labs/bedrock/actions/workflows/build.yaml/badge.svg)](https://github.com/z5labs/bedrock/actions/workflows/build.yaml) **bedrock provides a minimal, modular and composable foundation for From 6c1d9300be806edde4ca8d71c4e348ac97a6dd03 Mon Sep 17 00:00:00 2001 From: Richard Carson Derr Date: Thu, 22 Aug 2024 22:49:16 -0400 Subject: [PATCH 04/12] refactor(issue-248): remove unnecessary method validation code --- rest/endpoint/validate.go | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/rest/endpoint/validate.go b/rest/endpoint/validate.go index dac808d..eb03a96 100644 --- a/rest/endpoint/validate.go +++ b/rest/endpoint/validate.go @@ -21,31 +21,6 @@ func validateRequest(r *http.Request, validators ...func(*http.Request) error) e return nil } -// InvalidMethodError represents when a request was sent to an endpoint -// for the incorrect method. -type InvalidMethodError struct { - Method string -} - -// Error implements [error] interface. -func (e InvalidMethodError) Error() string { - return fmt.Sprintf("received invalid method for endpoint: %s", e.Method) -} - -// ServeHTTP implements the [http.Handler] interface. -func (InvalidMethodError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusMethodNotAllowed) -} - -func validateMethod(method string) func(*http.Request) error { - return func(r *http.Request) error { - if r.Method == method { - return nil - } - return InvalidMethodError{Method: r.Method} - } -} - // InvalidHeaderError type InvalidHeaderError struct { Header string From e59d51aefa920a23326b86bc20dff1d5bfe8679b Mon Sep 17 00:00:00 2001 From: GitHub Action Date: Fri, 23 Aug 2024 02:49:53 +0000 Subject: [PATCH 05/12] chore(docs): updated coverage badge. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 89c8627..c099d08 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) [![Go Reference](https://pkg.go.dev/badge/github.com/z5labs/bedrock.svg)](https://pkg.go.dev/github.com/z5labs/bedrock) [![Go Report Card](https://goreportcard.com/badge/github.com/z5labs/bedrock)](https://goreportcard.com/report/github.com/z5labs/bedrock) -![Coverage](https://img.shields.io/badge/Coverage-94.1%25-brightgreen) +![Coverage](https://img.shields.io/badge/Coverage-95.1%25-brightgreen) [![build](https://github.com/z5labs/bedrock/actions/workflows/build.yaml/badge.svg)](https://github.com/z5labs/bedrock/actions/workflows/build.yaml) **bedrock provides a minimal, modular and composable foundation for From 626b7a4c6142eeafeed48323dae6878eaff0d848 Mon Sep 17 00:00:00 2001 From: Richard Carson Derr Date: Thu, 22 Aug 2024 23:15:17 -0400 Subject: [PATCH 06/12] refactor(issue-248): force user to handle errors from endpoint operation --- rest/endpoint/endpoint.go | 21 ++---- rest/endpoint/endpoint_test.go | 115 +++++++++++++++++++++------------ rest/endpoint/validate.go | 20 ------ 3 files changed, 81 insertions(+), 75 deletions(-) diff --git a/rest/endpoint/endpoint.go b/rest/endpoint/endpoint.go index dfe077b..3c47003 100644 --- a/rest/endpoint/endpoint.go +++ b/rest/endpoint/endpoint.go @@ -351,26 +351,26 @@ func (op *Operation[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request err := validateRequest(r, op.validators...) if err != nil { - op.handleError(w, r, err) + op.errHandler.HandleError(w, err) return } var req Req err = unmarshal(r.Body, &req) if err != nil { - op.handleError(w, r, err) + op.errHandler.HandleError(w, err) return } err = validate(req) if err != nil { - op.handleError(w, r, err) + op.errHandler.HandleError(w, err) return } resp, err := op.handler.Handle(ctx, req) if err != nil { - op.handleError(w, r, err) + op.errHandler.HandleError(w, err) return } @@ -382,7 +382,7 @@ func (op *Operation[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request b, err := bm.MarshalBinary() if err != nil { - op.handleError(w, r, err) + op.errHandler.HandleError(w, err) return } @@ -393,20 +393,11 @@ func (op *Operation[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request w.WriteHeader(op.statusCode) _, err = io.Copy(w, bytes.NewReader(b)) if err != nil { - op.handleError(w, r, err) + op.errHandler.HandleError(w, err) return } } -func (op *Operation[Req, Resp]) handleError(w http.ResponseWriter, r *http.Request, err error) { - if h, ok := err.(http.Handler); ok { - h.ServeHTTP(w, r) - return - } - - op.errHandler.HandleError(w, err) -} - func unmarshal[Req any](r io.ReadCloser, req *Req) error { switch x := any(req).(type) { case encoding.BinaryUnmarshaler: diff --git a/rest/endpoint/endpoint_test.go b/rest/endpoint/endpoint_test.go index 073687f..48cbef3 100644 --- a/rest/endpoint/endpoint_test.go +++ b/rest/endpoint/endpoint_test.go @@ -40,18 +40,6 @@ func (x JsonContent) MarshalBinary() ([]byte, error) { return json.Marshal(x) } -type httpError struct { - status int -} - -func (httpError) Error() string { - return "" -} - -func (e httpError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(e.status) -} - type FailUnmarshalBinary struct{} var errUnmarshalBinary = errors.New("failed to unmarshal from binary") @@ -355,30 +343,8 @@ func TestEndpoint_ServeHTTP(t *testing.T) { } }) - t.Run("if the underlying error implements http.Handler", func(t *testing.T) { - errStatusCode := http.StatusServiceUnavailable - if !assert.NotEqual(t, DefaultErrorStatusCode, errStatusCode) { - return - } - - e := New( - HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { - return Empty{}, httpError{status: errStatusCode} - }), - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, "/", nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, errStatusCode, resp.StatusCode) { - return - } - }) - t.Run("if a required http header is missing", func(t *testing.T) { + var caughtError error e := New( noopHandler{}, Headers( @@ -387,6 +353,11 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Required: true, }, ), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() @@ -395,12 +366,21 @@ func TestEndpoint_ServeHTTP(t *testing.T) { e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var herr MissingRequiredHeaderError + if !assert.ErrorAs(t, caughtError, &herr) { + return + } + if !assert.NotEmpty(t, herr.Error()) { return } }) t.Run("if a http header does not match its expected pattern", func(t *testing.T) { + var caughtError error e := New( noopHandler{}, Headers( @@ -409,6 +389,11 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Pattern: "^[a-zA-Z]*$", }, ), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() @@ -418,12 +403,21 @@ func TestEndpoint_ServeHTTP(t *testing.T) { e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var herr InvalidHeaderError + if !assert.ErrorAs(t, caughtError, &herr) { + return + } + if !assert.NotEmpty(t, herr.Error()) { return } }) t.Run("if a required query param is missing", func(t *testing.T) { + var caughtError error e := New( noopHandler{}, QueryParams( @@ -432,6 +426,11 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Required: true, }, ), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() @@ -440,12 +439,21 @@ func TestEndpoint_ServeHTTP(t *testing.T) { e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var qerr MissingRequiredQueryParamError + if !assert.ErrorAs(t, caughtError, &qerr) { + return + } + if !assert.NotEmpty(t, qerr.Error()) { return } }) t.Run("if a query param does not match its expected pattern", func(t *testing.T) { + var caughtError error e := New( noopHandler{}, QueryParams( @@ -454,6 +462,11 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Pattern: "^[a-zA-Z]*$", }, ), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() @@ -462,16 +475,30 @@ func TestEndpoint_ServeHTTP(t *testing.T) { e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var qerr InvalidQueryParamError + if !assert.ErrorAs(t, caughtError, &qerr) { + return + } + if !assert.NotEmpty(t, qerr.Error()) { return } }) t.Run("if the request content type header does not match the content type from ContentTyper", func(t *testing.T) { + var caughtError error e := New( HandlerFunc[JsonContent, Empty](func(_ context.Context, _ JsonContent) (Empty, error) { return Empty{}, nil }), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() @@ -481,7 +508,15 @@ func TestEndpoint_ServeHTTP(t *testing.T) { e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var herr InvalidHeaderError + if !assert.ErrorAs(t, caughtError, &herr) { + return + } + if !assert.NotEmpty(t, herr.Error()) { return } }) diff --git a/rest/endpoint/validate.go b/rest/endpoint/validate.go index eb03a96..5eadfd1 100644 --- a/rest/endpoint/validate.go +++ b/rest/endpoint/validate.go @@ -31,11 +31,6 @@ func (e InvalidHeaderError) Error() string { return fmt.Sprintf("received invalid header for endpoint: %s", e.Header) } -// ServeHTTP implements the [http.Handler] interface. -func (InvalidHeaderError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) -} - // MissingRequiredHeaderError type MissingRequiredHeaderError struct { Header string @@ -46,11 +41,6 @@ func (e MissingRequiredHeaderError) Error() string { return fmt.Sprintf("missing required header for endpoint: %s", e.Header) } -// ServeHTTP implements the [http.Handler] interface. -func (MissingRequiredHeaderError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) -} - func validateHeader(h Header) func(*http.Request) error { var pattern *regexp.Regexp if h.Pattern != "" { @@ -82,11 +72,6 @@ func (e InvalidQueryParamError) Error() string { return fmt.Sprintf("received invalid query param for endpoint: %s", e.Param) } -// ServeHTTP implements the [http.Handler] interface. -func (InvalidQueryParamError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) -} - // MissingRequiredQueryParamError type MissingRequiredQueryParamError struct { Param string @@ -97,11 +82,6 @@ func (e MissingRequiredQueryParamError) Error() string { return fmt.Sprintf("missing required query param for endpoint: %s", e.Param) } -// ServeHTTP implements the [http.Handler] interface. -func (MissingRequiredQueryParamError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) -} - func validateQueryParam(qp QueryParam) func(*http.Request) error { var pattern *regexp.Regexp if qp.Pattern != "" { From 73abf53a6caded5557fea213e511982e928cc515 Mon Sep 17 00:00:00 2001 From: GitHub Action Date: Fri, 23 Aug 2024 03:16:16 +0000 Subject: [PATCH 07/12] chore(docs): updated coverage badge. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c099d08..b572e52 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) [![Go Reference](https://pkg.go.dev/badge/github.com/z5labs/bedrock.svg)](https://pkg.go.dev/github.com/z5labs/bedrock) [![Go Report Card](https://goreportcard.com/badge/github.com/z5labs/bedrock)](https://goreportcard.com/report/github.com/z5labs/bedrock) -![Coverage](https://img.shields.io/badge/Coverage-95.1%25-brightgreen) +![Coverage](https://img.shields.io/badge/Coverage-95.7%25-brightgreen) [![build](https://github.com/z5labs/bedrock/actions/workflows/build.yaml/badge.svg)](https://github.com/z5labs/bedrock/actions/workflows/build.yaml) **bedrock provides a minimal, modular and composable foundation for From af2126f23a5584a7e7e4d186516b981200d46fde Mon Sep 17 00:00:00 2001 From: Richard Carson Derr Date: Thu, 22 Aug 2024 23:38:38 -0400 Subject: [PATCH 08/12] feat(issue-248): add otel spans everywhere --- rest/endpoint/endpoint.go | 107 +++++++++++++++++++++++++------------- rest/endpoint/inject.go | 5 ++ rest/endpoint/validate.go | 9 +++- 3 files changed, 84 insertions(+), 37 deletions(-) diff --git a/rest/endpoint/endpoint.go b/rest/endpoint/endpoint.go index 3c47003..c9a52b1 100644 --- a/rest/endpoint/endpoint.go +++ b/rest/endpoint/endpoint.go @@ -18,6 +18,7 @@ import ( "github.com/swaggest/jsonschema-go" "github.com/swaggest/openapi-go/openapi3" + "go.opentelemetry.io/otel" ) // Empty @@ -42,9 +43,6 @@ type ErrorHandler interface { } type options struct { - method string - pattern string - pathParams map[PathParam]struct{} headerParams map[Header]struct{} queryParams map[QueryParam]struct{} @@ -347,58 +345,47 @@ func (op *Operation[Req, Resp]) OpenApi() *openapi3.Operation { // ServeHTTP implements the [http.Handler] interface. func (op *Operation[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := inject(r.Context(), w, r, op.injectors...) + spanCtx, span := otel.Tracer("endpoint").Start(r.Context(), "Operation.ServeHTTP") + defer span.End() + + ctx := inject(spanCtx, w, r, op.injectors...) - err := validateRequest(r, op.validators...) + err := validateRequest(ctx, r, op.validators...) if err != nil { - op.errHandler.HandleError(w, err) + op.handleError(ctx, w, err) return } var req Req - err = unmarshal(r.Body, &req) + err = unmarshal(ctx, r.Body, &req) if err != nil { - op.errHandler.HandleError(w, err) + op.handleError(ctx, w, err) return } - err = validate(req) + err = validate(ctx, req) if err != nil { - op.errHandler.HandleError(w, err) + op.handleError(ctx, w, err) return } resp, err := op.handler.Handle(ctx, req) if err != nil { - op.errHandler.HandleError(w, err) - return - } - - bm, ok := any(resp).(encoding.BinaryMarshaler) - if !ok { - w.WriteHeader(op.statusCode) + op.handleError(ctx, w, err) return } - b, err := bm.MarshalBinary() + err = op.writeResponse(ctx, w, resp) if err != nil { - op.errHandler.HandleError(w, err) - return - } - - if ct, ok := any(resp).(ContentTyper); ok { - w.Header().Set("Content-Type", ct.ContentType()) - } - - w.WriteHeader(op.statusCode) - _, err = io.Copy(w, bytes.NewReader(b)) - if err != nil { - op.errHandler.HandleError(w, err) + op.handleError(ctx, w, err) return } } -func unmarshal[Req any](r io.ReadCloser, req *Req) error { +func unmarshal[Req any](ctx context.Context, r io.ReadCloser, req *Req) error { + _, span := otel.Tracer("endpoint").Start(ctx, "unmarshal") + defer span.End() + switch x := any(req).(type) { case encoding.BinaryUnmarshaler: defer func() { @@ -407,12 +394,16 @@ func unmarshal[Req any](r io.ReadCloser, req *Req) error { b, err := io.ReadAll(r) if err != nil { + span.RecordError(err) return err } - return x.UnmarshalBinary(b) + err = x.UnmarshalBinary(b) + span.RecordError(err) + return err case io.ReaderFrom: _, err := x.ReadFrom(r) + span.RecordError(err) return err default: return nil @@ -424,9 +415,53 @@ type Validator interface { Validate() error } -func validate[Req any](req Req) error { - if v, ok := any(req).(Validator); ok { - return v.Validate() +func validate[Req any](ctx context.Context, req Req) error { + _, span := otel.Tracer("endpoint").Start(ctx, "validate") + defer span.End() + + v, ok := any(req).(Validator) + if !ok { + return nil + } + + err := v.Validate() + span.RecordError(err) + return err +} + +func (op *Operation[Req, Resp]) writeResponse(ctx context.Context, w http.ResponseWriter, resp Resp) error { + _, span := otel.Tracer("endpoint").Start(ctx, "Operation.writeResponse") + defer span.End() + + switch x := any(resp).(type) { + case io.WriterTo: + _, err := x.WriteTo(w) + span.RecordError(err) + return err + case encoding.BinaryMarshaler: + b, err := x.MarshalBinary() + if err != nil { + span.RecordError(err) + return err + } + + if ct, ok := any(resp).(ContentTyper); ok { + w.Header().Set("Content-Type", ct.ContentType()) + } + + w.WriteHeader(op.statusCode) + _, err = io.Copy(w, bytes.NewReader(b)) + span.RecordError(err) + return err + default: + w.WriteHeader(op.statusCode) + return nil } - return nil +} + +func (op *Operation[Req, Resp]) handleError(ctx context.Context, w http.ResponseWriter, err error) { + _, span := otel.Tracer("endpoint").Start(ctx, "Operation.handleError") + defer span.End() + + op.errHandler.HandleError(w, err) } diff --git a/rest/endpoint/inject.go b/rest/endpoint/inject.go index 0179345..c531e3c 100644 --- a/rest/endpoint/inject.go +++ b/rest/endpoint/inject.go @@ -9,11 +9,16 @@ import ( "context" "net/http" "net/url" + + "go.opentelemetry.io/otel" ) type injector func(context.Context, http.ResponseWriter, *http.Request) context.Context func inject(ctx context.Context, w http.ResponseWriter, r *http.Request, injectors ...injector) context.Context { + _, span := otel.Tracer("endpoint").Start(ctx, "inject") + defer span.End() + for _, injector := range injectors { ctx = injector(ctx, w, r) } diff --git a/rest/endpoint/validate.go b/rest/endpoint/validate.go index 5eadfd1..1370ed3 100644 --- a/rest/endpoint/validate.go +++ b/rest/endpoint/validate.go @@ -6,15 +6,22 @@ package endpoint import ( + "context" "fmt" "net/http" "regexp" + + "go.opentelemetry.io/otel" ) -func validateRequest(r *http.Request, validators ...func(*http.Request) error) error { +func validateRequest(ctx context.Context, r *http.Request, validators ...func(*http.Request) error) error { + _, span := otel.Tracer("endpoint").Start(ctx, "validateRequest") + defer span.End() + for _, validator := range validators { err := validator(r) if err != nil { + span.RecordError(err) return err } } From 503cbbcdf0fe2b25991fa1ab3f6ecef5dd47cdee Mon Sep 17 00:00:00 2001 From: GitHub Action Date: Fri, 23 Aug 2024 03:39:28 +0000 Subject: [PATCH 09/12] chore(docs): updated coverage badge. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b572e52..e3b5466 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) [![Go Reference](https://pkg.go.dev/badge/github.com/z5labs/bedrock.svg)](https://pkg.go.dev/github.com/z5labs/bedrock) [![Go Report Card](https://goreportcard.com/badge/github.com/z5labs/bedrock)](https://goreportcard.com/report/github.com/z5labs/bedrock) -![Coverage](https://img.shields.io/badge/Coverage-95.7%25-brightgreen) +![Coverage](https://img.shields.io/badge/Coverage-95.4%25-brightgreen) [![build](https://github.com/z5labs/bedrock/actions/workflows/build.yaml/badge.svg)](https://github.com/z5labs/bedrock/actions/workflows/build.yaml) **bedrock provides a minimal, modular and composable foundation for From 86f346f2b15fc772c330c2f9a5e4bb335b8b523e Mon Sep 17 00:00:00 2001 From: Richard Carson Derr Date: Thu, 22 Aug 2024 23:43:40 -0400 Subject: [PATCH 10/12] refactor(issue-248): remove empty type --- rest/endpoint/endpoint.go | 3 --- rest/endpoint/endpoint_test.go | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/rest/endpoint/endpoint.go b/rest/endpoint/endpoint.go index c9a52b1..1dd1adb 100644 --- a/rest/endpoint/endpoint.go +++ b/rest/endpoint/endpoint.go @@ -21,9 +21,6 @@ import ( "go.opentelemetry.io/otel" ) -// Empty -type Empty struct{} - // Handler type Handler[Req, Resp any] interface { Handle(context.Context, Req) (Resp, error) diff --git a/rest/endpoint/endpoint_test.go b/rest/endpoint/endpoint_test.go index 48cbef3..eeb3f58 100644 --- a/rest/endpoint/endpoint_test.go +++ b/rest/endpoint/endpoint_test.go @@ -18,6 +18,8 @@ import ( "github.com/stretchr/testify/assert" ) +type Empty struct{} + type noopHandler struct{} func (noopHandler) Handle(_ context.Context, _ Empty) (Empty, error) { From 967bfaa76fc69bc8fdc57f99b7e893760b198fda Mon Sep 17 00:00:00 2001 From: Richard Carson Derr Date: Fri, 23 Aug 2024 00:03:18 -0400 Subject: [PATCH 11/12] refactor(issue-248): almost forgot to add operation to spec --- example/simple_rest/app/app.go | 6 ++++-- rest/endpoint/endpoint.go | 12 +++++------ rest/endpoint/endpoint_test.go | 38 +++++++++++++++++----------------- rest/endpoint/openapi_test.go | 20 +++++++++--------- rest/rest.go | 29 +++++++++++++++++++++++--- rest/rest_example_test.go | 5 ++++- 6 files changed, 69 insertions(+), 41 deletions(-) diff --git a/example/simple_rest/app/app.go b/example/simple_rest/app/app.go index 2e36819..7b449bf 100644 --- a/example/simple_rest/app/app.go +++ b/example/simple_rest/app/app.go @@ -8,6 +8,7 @@ package app import ( "context" "log/slog" + "net/http" "os" "github.com/z5labs/bedrock" @@ -40,8 +41,9 @@ func Init(ctx context.Context, cfg Config) (bedrock.App, error) { restApp := rest.NewApp( rest.ListenOn(cfg.Http.Port), rest.Endpoint( - "POST /echo", - endpoint.New( + http.MethodPost, + "/echo", + endpoint.NewOperation( echoService, endpoint.Headers( endpoint.Header{ diff --git a/rest/endpoint/endpoint.go b/rest/endpoint/endpoint.go index 1dd1adb..9866c9c 100644 --- a/rest/endpoint/endpoint.go +++ b/rest/endpoint/endpoint.go @@ -48,7 +48,7 @@ type options struct { validators []func(*http.Request) error errHandler ErrorHandler - openapi *openapi3.Operation + openapi openapi3.Operation } // Option @@ -64,7 +64,7 @@ type Operation[Req, Resp any] struct { errHandler ErrorHandler - openapi *openapi3.Operation + openapi openapi3.Operation } const DefaultStatusCode = http.StatusOK @@ -271,8 +271,8 @@ func (f errorHandlerFunc) HandleError(w http.ResponseWriter, err error) { // DefaultErrorStatusCode const DefaultErrorStatusCode = http.StatusInternalServerError -// New initializes an Endpoint. -func New[Req, Resp any](handler Handler[Req, Resp], opts ...Option) *Operation[Req, Resp] { +// NewOperation initializes a Operation. +func NewOperation[Req, Resp any](handler Handler[Req, Resp], opts ...Option) *Operation[Req, Resp] { o := &options{ defaultStatusCode: DefaultStatusCode, pathParams: make(map[PathParam]struct{}), @@ -281,7 +281,7 @@ func New[Req, Resp any](handler Handler[Req, Resp], opts ...Option) *Operation[R errHandler: errorHandlerFunc(func(w http.ResponseWriter, err error) { w.WriteHeader(DefaultErrorStatusCode) }), - openapi: &openapi3.Operation{ + openapi: openapi3.Operation{ Responses: openapi3.Responses{ MapOfResponseOrRefValues: make(map[string]openapi3.ResponseOrRef), }, @@ -336,7 +336,7 @@ func initInjectors(o *options) []injector { return injectors } -func (op *Operation[Req, Resp]) OpenApi() *openapi3.Operation { +func (op *Operation[Req, Resp]) OpenApi() openapi3.Operation { return op.openapi } diff --git a/rest/endpoint/endpoint_test.go b/rest/endpoint/endpoint_test.go index eeb3f58..63cbd01 100644 --- a/rest/endpoint/endpoint_test.go +++ b/rest/endpoint/endpoint_test.go @@ -73,7 +73,7 @@ func (FailMarshalBinary) MarshalBinary() ([]byte, error) { func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return the default success http status code", func(t *testing.T) { t.Run("if the underlying Handler succeeds with an empty response", func(t *testing.T) { - e := New(noopHandler{}) + e := NewOperation(noopHandler{}) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) @@ -87,7 +87,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the underlying Handler succeeds with a encoding.BinaryMarshaler response", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{Value: "hello, world"}, nil }), @@ -121,7 +121,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will inject path params", func(t *testing.T) { t.Run("if a valid http.ServeMux path param pattern is used", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, JsonContent](func(ctx context.Context, _ Empty) (JsonContent, error) { v := PathValue(ctx, "id") return JsonContent{Value: v}, nil @@ -163,7 +163,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will inject headers", func(t *testing.T) { t.Run("if a header is configured with the Headers option", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, JsonContent](func(ctx context.Context, _ Empty) (JsonContent, error) { v := HeaderValue(ctx, "test-header") return JsonContent{Value: v}, nil @@ -202,7 +202,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will inject query params", func(t *testing.T) { t.Run("if a query param is configured with the QueryParams option", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, JsonContent](func(ctx context.Context, _ Empty) (JsonContent, error) { v := QueryValue(ctx, "test-query") return JsonContent{Value: v}, nil @@ -245,7 +245,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { return } - e := New( + e := NewOperation( noopHandler{}, StatusCode(statusCode), ) @@ -267,7 +267,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { return } - e := New( + e := NewOperation( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{Value: "hello, world"}, nil }), @@ -302,7 +302,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return non-success http status code", func(t *testing.T) { t.Run("if the underlying Handler returns an error", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, errors.New("failed") }), @@ -325,7 +325,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { return } - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, errors.New("failed") }), @@ -347,7 +347,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("if a required http header is missing", func(t *testing.T) { var caughtError error - e := New( + e := NewOperation( noopHandler{}, Headers( Header{ @@ -383,7 +383,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("if a http header does not match its expected pattern", func(t *testing.T) { var caughtError error - e := New( + e := NewOperation( noopHandler{}, Headers( Header{ @@ -420,7 +420,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("if a required query param is missing", func(t *testing.T) { var caughtError error - e := New( + e := NewOperation( noopHandler{}, QueryParams( QueryParam{ @@ -456,7 +456,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("if a query param does not match its expected pattern", func(t *testing.T) { var caughtError error - e := New( + e := NewOperation( noopHandler{}, QueryParams( QueryParam{ @@ -492,7 +492,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("if the request content type header does not match the content type from ContentTyper", func(t *testing.T) { var caughtError error - e := New( + e := NewOperation( HandlerFunc[JsonContent, Empty](func(_ context.Context, _ JsonContent) (Empty, error) { return Empty{}, nil }), @@ -525,7 +525,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("if the request body fails to unmarshal", func(t *testing.T) { var caughtError error - e := New( + e := NewOperation( HandlerFunc[FailUnmarshalBinary, Empty](func(_ context.Context, _ FailUnmarshalBinary) (Empty, error) { return Empty{}, nil }), @@ -552,7 +552,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("if the unmarshaled request body is invalid", func(t *testing.T) { var caughtError error - e := New( + e := NewOperation( HandlerFunc[InvalidRequest, Empty](func(_ context.Context, _ InvalidRequest) (Empty, error) { return Empty{}, nil }), @@ -579,7 +579,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("if the response body fails to marshal itself to binary", func(t *testing.T) { var caughtError error - e := New( + e := NewOperation( HandlerFunc[Empty, FailMarshalBinary](func(_ context.Context, _ Empty) (FailMarshalBinary, error) { return FailMarshalBinary{}, nil }), @@ -607,7 +607,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return response header", func(t *testing.T) { t.Run("if the response body implements ContentTyper", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{Value: "hello, world"}, nil }), @@ -642,7 +642,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the underlying Handler sets a custom response header using the context", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(ctx context.Context, _ Empty) (Empty, error) { SetResponseHeader(ctx, "Content-Type", "test-content-type") return Empty{}, nil diff --git a/rest/endpoint/openapi_test.go b/rest/endpoint/openapi_test.go index 3cbd45f..594c0f4 100644 --- a/rest/endpoint/openapi_test.go +++ b/rest/endpoint/openapi_test.go @@ -21,7 +21,7 @@ import ( func TestEndpoint_OpenApi(t *testing.T) { t.Run("will required path parameter", func(t *testing.T) { t.Run("if a http.ServeMux path parameter pattern is used", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), @@ -70,7 +70,7 @@ func TestEndpoint_OpenApi(t *testing.T) { Required: true, } - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), @@ -116,7 +116,7 @@ func TestEndpoint_OpenApi(t *testing.T) { Required: true, } - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), @@ -161,7 +161,7 @@ func TestEndpoint_OpenApi(t *testing.T) { Name: "myparam", } - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), @@ -207,7 +207,7 @@ func TestEndpoint_OpenApi(t *testing.T) { Required: true, } - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), @@ -248,7 +248,7 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set request body type", func(t *testing.T) { t.Run("if the request type implements ContentTyper interface", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[JsonContent, Empty](func(_ context.Context, _ JsonContent) (Empty, error) { return Empty{}, nil }), @@ -305,7 +305,7 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set response body type", func(t *testing.T) { t.Run("if the response type implements ContentTyper interface", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{}, nil }), @@ -365,7 +365,7 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set a empty response body", func(t *testing.T) { t.Run("if the response type does not implement ContentTyper", func(t *testing.T) { - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), @@ -404,7 +404,7 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("if the Returns option is used with a http status code", func(t *testing.T) { statusCode := http.StatusBadRequest - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), @@ -458,7 +458,7 @@ func TestEndpoint_OpenApi(t *testing.T) { return } - e := New( + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), diff --git a/rest/rest.go b/rest/rest.go index 6078238..77e3886 100644 --- a/rest/rest.go +++ b/rest/rest.go @@ -38,14 +38,17 @@ func ListenOn(port uint) Option { type Operation interface { http.Handler - OpenApi() *openapi3.Operation + OpenApi() openapi3.Operation } // Endpoint registers the [Operation] with both // the App wide OpenAPI spec and the App wide HTTP server. -func Endpoint(pattern string, op Operation) Option { +func Endpoint(method, pattern string, op Operation) Option { return func(app *App) { - // h.OpenApi(app.spec) + err := app.spec.AddOperation(method, pattern, op.OpenApi()) + if err != nil { + panic(err) + } app.mux.Handle( pattern, @@ -54,6 +57,26 @@ func Endpoint(pattern string, op Operation) Option { } } +// Title sets the title of the API in its OpenAPI spec. +// +// In order for your OpenAPI spec to be fully compliant +// with other tooling, this option is required. +func Title(s string) Option { + return func(a *App) { + a.spec.Info.Title = s + } +} + +// Version sets the API version in its OpenAPI spec. +// +// In order for your OpenAPI spec to be fully compliant +// with other tooling, this option is required. +func Version(s string) Option { + return func(a *App) { + a.spec.Info.Version = s + } +} + // App is a [bedrock.App] implementation to help simplify // building RESTful applications. type App struct { diff --git a/rest/rest_example_test.go b/rest/rest_example_test.go index 57e8ff3..5685880 100644 --- a/rest/rest_example_test.go +++ b/rest/rest_example_test.go @@ -73,9 +73,12 @@ func Example() { app := NewApp( listenOnRandomPort(addrCh), + Title("Example"), + Version("v0.0.0"), Endpoint( + http.MethodPost, "/", - endpoint.New( + endpoint.NewOperation( echoService{}, ), ), From c2d7a51bf737387b64be585fd166724c29915c86 Mon Sep 17 00:00:00 2001 From: GitHub Action Date: Fri, 23 Aug 2024 04:04:17 +0000 Subject: [PATCH 12/12] chore(docs): updated coverage badge. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e3b5466..d9ead4a 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) [![Go Reference](https://pkg.go.dev/badge/github.com/z5labs/bedrock.svg)](https://pkg.go.dev/github.com/z5labs/bedrock) [![Go Report Card](https://goreportcard.com/badge/github.com/z5labs/bedrock)](https://goreportcard.com/report/github.com/z5labs/bedrock) -![Coverage](https://img.shields.io/badge/Coverage-95.4%25-brightgreen) +![Coverage](https://img.shields.io/badge/Coverage-95.3%25-brightgreen) [![build](https://github.com/z5labs/bedrock/actions/workflows/build.yaml/badge.svg)](https://github.com/z5labs/bedrock/actions/workflows/build.yaml) **bedrock provides a minimal, modular and composable foundation for