diff --git a/README.md b/README.md index ff0e425..d9ead4a 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) [![Go Reference](https://pkg.go.dev/badge/github.com/z5labs/bedrock.svg)](https://pkg.go.dev/github.com/z5labs/bedrock) [![Go Report Card](https://goreportcard.com/badge/github.com/z5labs/bedrock)](https://goreportcard.com/report/github.com/z5labs/bedrock) -![Coverage](https://img.shields.io/badge/Coverage-95.0%25-brightgreen) +![Coverage](https://img.shields.io/badge/Coverage-95.3%25-brightgreen) [![build](https://github.com/z5labs/bedrock/actions/workflows/build.yaml/badge.svg)](https://github.com/z5labs/bedrock/actions/workflows/build.yaml) **bedrock provides a minimal, modular and composable foundation for diff --git a/example/simple_rest/app/app.go b/example/simple_rest/app/app.go index b2d39f8..7b449bf 100644 --- a/example/simple_rest/app/app.go +++ b/example/simple_rest/app/app.go @@ -8,6 +8,7 @@ package app import ( "context" "log/slog" + "net/http" "os" "github.com/z5labs/bedrock" @@ -39,10 +40,10 @@ func Init(ctx context.Context, cfg Config) (bedrock.App, error) { restApp := rest.NewApp( rest.ListenOn(cfg.Http.Port), - rest.Handle( + rest.Endpoint( + http.MethodPost, "/echo", - endpoint.Post( - "/echo", + endpoint.NewOperation( echoService, endpoint.Headers( endpoint.Header{ diff --git a/rest/endpoint/endpoint.go b/rest/endpoint/endpoint.go index acf3dca..9866c9c 100644 --- a/rest/endpoint/endpoint.go +++ b/rest/endpoint/endpoint.go @@ -12,19 +12,15 @@ import ( "fmt" "io" "net/http" - "reflect" "strconv" - "strings" "github.com/z5labs/bedrock/pkg/ptr" "github.com/swaggest/jsonschema-go" "github.com/swaggest/openapi-go/openapi3" + "go.opentelemetry.io/otel" ) -// Empty -type Empty struct{} - // Handler type Handler[Req, Resp any] interface { Handle(context.Context, Req) (Resp, error) @@ -44,29 +40,22 @@ type ErrorHandler interface { } type options struct { - method string - pattern string + pathParams map[PathParam]struct{} + headerParams map[Header]struct{} + queryParams map[QueryParam]struct{} defaultStatusCode int validators []func(*http.Request) error errHandler ErrorHandler - schemas map[string]*openapi3.Schema - pathParams []*openapi3.Parameter - headers []*openapi3.Parameter - queryParams []*openapi3.Parameter - request *openapi3.RequestBody - responses *openapi3.Responses + openapi openapi3.Operation } // Option type Option func(*options) -// Endpoint -type Endpoint[Req, Resp any] struct { - method string - pattern string - +// Operation +type Operation[Req, Resp any] struct { validators []func(*http.Request) error injectors []injector @@ -75,7 +64,7 @@ type Endpoint[Req, Resp any] struct { errHandler ErrorHandler - openapi func(*openapi3.Spec) + openapi openapi3.Operation } const DefaultStatusCode = http.StatusOK @@ -87,48 +76,29 @@ func StatusCode(statusCode int) Option { } } -type pathParam struct { - name string -} - -func parsePathParams(s string) []pathParam { - var params []pathParam - var found bool - for { - if len(s) == 0 { - return params - } - - _, s, found = strings.Cut(s, "{") - if !found { - return params - } - - i := strings.IndexByte(s, '}') - if i == -1 { - return params - } - - param := s[:i] - s = s[i:] - - name := strings.TrimSuffix(param, ".") - params = append(params, pathParam{ - name: name, - }) - } +// PathParam +type PathParam struct { + Name string + Pattern string + Required bool } -func pathParams(ps ...pathParam) Option { +// PathParams +func PathParams(ps ...PathParam) Option { return func(o *options) { for _, p := range ps { - o.pathParams = append(o.pathParams, &openapi3.Parameter{ - In: openapi3.ParameterInPath, - Name: p.name, - Required: ptr.Ref(true), - Schema: &openapi3.SchemaOrRef{ - Schema: &openapi3.Schema{ - Type: ptr.Ref(openapi3.SchemaTypeString), + o.pathParams[p] = struct{}{} + + o.openapi.Parameters = append(o.openapi.Parameters, openapi3.ParameterOrRef{ + Parameter: &openapi3.Parameter{ + In: openapi3.ParameterInPath, + Name: p.Name, + Required: ptr.Ref(p.Required), + Schema: &openapi3.SchemaOrRef{ + Schema: &openapi3.Schema{ + Type: ptr.Ref(openapi3.SchemaTypeString), + Pattern: ptr.Ref(p.Pattern), + }, }, }, }) @@ -147,13 +117,18 @@ type Header struct { func Headers(hs ...Header) Option { return func(o *options) { for _, h := range hs { - o.headers = append(o.headers, &openapi3.Parameter{ - In: openapi3.ParameterInHeader, - Name: h.Name, - Required: ptr.Ref(h.Required), - Schema: &openapi3.SchemaOrRef{ - Schema: &openapi3.Schema{ - Type: ptr.Ref(openapi3.SchemaTypeString), + o.headerParams[h] = struct{}{} + + o.openapi.Parameters = append(o.openapi.Parameters, openapi3.ParameterOrRef{ + Parameter: &openapi3.Parameter{ + In: openapi3.ParameterInHeader, + Name: h.Name, + Required: ptr.Ref(h.Required), + Schema: &openapi3.SchemaOrRef{ + Schema: &openapi3.Schema{ + Type: ptr.Ref(openapi3.SchemaTypeString), + Pattern: ptr.Ref(h.Pattern), + }, }, }, }) @@ -174,13 +149,18 @@ type QueryParam struct { func QueryParams(qps ...QueryParam) Option { return func(o *options) { for _, qp := range qps { - o.queryParams = append(o.queryParams, &openapi3.Parameter{ - In: openapi3.ParameterInQuery, - Name: qp.Name, - Required: ptr.Ref(qp.Required), - Schema: &openapi3.SchemaOrRef{ - Schema: &openapi3.Schema{ - Type: ptr.Ref(openapi3.SchemaTypeString), + o.queryParams[qp] = struct{}{} + + o.openapi.Parameters = append(o.openapi.Parameters, openapi3.ParameterOrRef{ + Parameter: &openapi3.Parameter{ + In: openapi3.ParameterInQuery, + Name: qp.Name, + Required: ptr.Ref(qp.Required), + Schema: &openapi3.SchemaOrRef{ + Schema: &openapi3.Schema{ + Type: ptr.Ref(openapi3.SchemaTypeString), + Pattern: ptr.Ref(qp.Pattern), + }, }, }, }) @@ -198,10 +178,6 @@ type ContentTyper interface { // Accepts func Accepts[Req any]() Option { return func(o *options) { - if o.request == nil { - o.request = new(openapi3.RequestBody) - } - contentType := "" var req Req @@ -224,18 +200,12 @@ func Accepts[Req any]() Option { var schemaOrRef openapi3.SchemaOrRef schemaOrRef.FromJSONSchema(schema.ToSchemaOrBool()) - typeName := reflect.TypeOf(req).Name() - schemaRef := fmt.Sprintf("#/components/schemas/%s", typeName) - o.schemas[typeName] = schemaOrRef.Schema - - o.request = &openapi3.RequestBody{ - Required: ptr.Ref(true), - Content: map[string]openapi3.MediaType{ - contentType: { - Schema: &openapi3.SchemaOrRef{ - SchemaReference: &openapi3.SchemaReference{ - Ref: schemaRef, - }, + o.openapi.RequestBody = &openapi3.RequestBodyOrRef{ + RequestBody: &openapi3.RequestBody{ + Required: ptr.Ref(true), + Content: map[string]openapi3.MediaType{ + contentType: { + Schema: &schemaOrRef, }, }, }, @@ -246,13 +216,7 @@ func Accepts[Req any]() Option { // Returns func Returns(status int) Option { return func(o *options) { - if o.responses == nil { - o.responses = &openapi3.Responses{ - MapOfResponseOrRefValues: make(map[string]openapi3.ResponseOrRef), - } - } - - o.responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ + o.openapi.Responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ Response: &openapi3.Response{}, } } @@ -261,19 +225,10 @@ func Returns(status int) Option { // ReturnsWith func ReturnsWith[Resp any](status int) Option { return func(o *options) { - if o.responses == nil { - o.responses = &openapi3.Responses{ - MapOfResponseOrRefValues: make(map[string]openapi3.ResponseOrRef), - } - } - - contentType := "" - var resp Resp - if ct, ok := any(resp).(ContentTyper); ok { - contentType = ct.ContentType() - } else { - o.responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ + ct, ok := any(resp).(ContentTyper) + if !ok { + o.openapi.Responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ Response: &openapi3.Response{}, } return @@ -288,19 +243,11 @@ func ReturnsWith[Resp any](status int) Option { var schemaOrRef openapi3.SchemaOrRef schemaOrRef.FromJSONSchema(schema.ToSchemaOrBool()) - typeName := reflect.TypeOf(resp).Name() - schemaRef := fmt.Sprintf("#/components/schemas/%s", typeName) - o.schemas[typeName] = schemaOrRef.Schema - - o.responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ + o.openapi.Responses.MapOfResponseOrRefValues[strconv.Itoa(status)] = openapi3.ResponseOrRef{ Response: &openapi3.Response{ Content: map[string]openapi3.MediaType{ - contentType: { - Schema: &openapi3.SchemaOrRef{ - SchemaReference: &openapi3.SchemaReference{ - Ref: schemaRef, - }, - }, + ct.ContentType(): { + Schema: &schemaOrRef, }, }, }, @@ -324,41 +271,38 @@ func (f errorHandlerFunc) HandleError(w http.ResponseWriter, err error) { // DefaultErrorStatusCode const DefaultErrorStatusCode = http.StatusInternalServerError -// New initializes an Endpoint. -func New[Req, Resp any](method string, pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { +// NewOperation initializes a Operation. +func NewOperation[Req, Resp any](handler Handler[Req, Resp], opts ...Option) *Operation[Req, Resp] { o := &options{ - method: method, - pattern: pattern, defaultStatusCode: DefaultStatusCode, - validators: []func(*http.Request) error{ - validateMethod(method), - }, + pathParams: make(map[PathParam]struct{}), + headerParams: make(map[Header]struct{}), + queryParams: make(map[QueryParam]struct{}), errHandler: errorHandlerFunc(func(w http.ResponseWriter, err error) { w.WriteHeader(DefaultErrorStatusCode) }), - schemas: make(map[string]*openapi3.Schema), + openapi: openapi3.Operation{ + Responses: openapi3.Responses{ + MapOfResponseOrRefValues: make(map[string]openapi3.ResponseOrRef), + }, + }, } - for _, opt := range withBuiltinOptions[Req, Resp](pattern, opts...) { + for _, opt := range withBuiltinOptions[Req, Resp](opts...) { opt(o) } - return &Endpoint[Req, Resp]{ - method: method, - pattern: pattern, + return &Operation[Req, Resp]{ injectors: initInjectors(o), validators: o.validators, statusCode: o.defaultStatusCode, handler: handler, errHandler: o.errHandler, - openapi: setOpenApiSpec(o), + openapi: o.openapi, } } -func withBuiltinOptions[Req, Resp any](pattern string, opts ...Option) []Option { - parsedPathParams := parsePathParams(pattern) - opts = append(opts, pathParams(parsedPathParams...)) - +func withBuiltinOptions[Req, Resp any](opts ...Option) []Option { var req Req if _, ok := any(req).(ContentTyper); ok { opts = append(opts, Accepts[Req]()) @@ -380,10 +324,10 @@ func withBuiltinOptions[Req, Resp any](pattern string, opts ...Option) []Option func initInjectors(o *options) []injector { injectors := []injector{injectResponseHeaders} - for _, p := range o.pathParams { + for p := range o.pathParams { injectors = append(injectors, injectPathParam(p.Name)) } - if len(o.headers) > 0 { + if len(o.headerParams) > 0 { injectors = append(injectors, injectHeaders) } if len(o.queryParams) > 0 { @@ -392,105 +336,53 @@ func initInjectors(o *options) []injector { return injectors } -// Get returns an Endpoint configured for handling HTTP GET requests. -func Get[Req, Resp any](pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { - return New(http.MethodGet, pattern, handler, opts...) -} - -// Post returns an Endpoint configured for handling HTTP POST requests. -func Post[Req, Resp any](pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { - return New(http.MethodPost, pattern, handler, opts...) -} - -// Put returns an Endpoint configured for handling HTTP PUT requests. -func Put[Req, Resp any](pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { - return New(http.MethodPut, pattern, handler, opts...) -} - -// Delete returns an Endpoint configured for handling HTTP DELETE requests. -func Delete[Req, Resp any](pattern string, handler Handler[Req, Resp], opts ...Option) *Endpoint[Req, Resp] { - return New(http.MethodDelete, pattern, handler, opts...) -} - -// Method returns the HTTP method which this endpoint -// is configured to handle requests for. -func (e *Endpoint[Req, Resp]) Method() string { - return e.method -} - -// Pattern returns HTTP path pattern for this endpoint. -func (e *Endpoint[Req, Resp]) Pattern() string { - return e.pattern -} - -// OpenApi allows the endpoint to register itself with an OpenAPI spec. -func (e *Endpoint[Req, Resp]) OpenApi(spec *openapi3.Spec) { - e.openapi(spec) +func (op *Operation[Req, Resp]) OpenApi() openapi3.Operation { + return op.openapi } // ServeHTTP implements the [http.Handler] interface. -func (e *Endpoint[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ctx := inject(r.Context(), w, r, e.injectors...) +func (op *Operation[Req, Resp]) ServeHTTP(w http.ResponseWriter, r *http.Request) { + spanCtx, span := otel.Tracer("endpoint").Start(r.Context(), "Operation.ServeHTTP") + defer span.End() + + ctx := inject(spanCtx, w, r, op.injectors...) - err := validateRequest(r, e.validators...) + err := validateRequest(ctx, r, op.validators...) if err != nil { - e.handleError(w, r, err) + op.handleError(ctx, w, err) return } var req Req - err = unmarshal(r.Body, &req) + err = unmarshal(ctx, r.Body, &req) if err != nil { - e.handleError(w, r, err) + op.handleError(ctx, w, err) return } - err = validate(req) + err = validate(ctx, req) if err != nil { - e.handleError(w, r, err) + op.handleError(ctx, w, err) return } - resp, err := e.handler.Handle(ctx, req) + resp, err := op.handler.Handle(ctx, req) if err != nil { - e.handleError(w, r, err) + op.handleError(ctx, w, err) return } - bm, ok := any(resp).(encoding.BinaryMarshaler) - if !ok { - w.WriteHeader(e.statusCode) - return - } - - b, err := bm.MarshalBinary() + err = op.writeResponse(ctx, w, resp) if err != nil { - e.handleError(w, r, err) - return - } - - if ct, ok := any(resp).(ContentTyper); ok { - w.Header().Set("Content-Type", ct.ContentType()) - } - - w.WriteHeader(e.statusCode) - _, err = io.Copy(w, bytes.NewReader(b)) - if err != nil { - e.handleError(w, r, err) + op.handleError(ctx, w, err) return } } -func (e *Endpoint[Req, Resp]) handleError(w http.ResponseWriter, r *http.Request, err error) { - if h, ok := err.(http.Handler); ok { - h.ServeHTTP(w, r) - return - } - - e.errHandler.HandleError(w, err) -} +func unmarshal[Req any](ctx context.Context, r io.ReadCloser, req *Req) error { + _, span := otel.Tracer("endpoint").Start(ctx, "unmarshal") + defer span.End() -func unmarshal[Req any](r io.ReadCloser, req *Req) error { switch x := any(req).(type) { case encoding.BinaryUnmarshaler: defer func() { @@ -499,12 +391,16 @@ func unmarshal[Req any](r io.ReadCloser, req *Req) error { b, err := io.ReadAll(r) if err != nil { + span.RecordError(err) return err } - return x.UnmarshalBinary(b) + err = x.UnmarshalBinary(b) + span.RecordError(err) + return err case io.ReaderFrom: _, err := x.ReadFrom(r) + span.RecordError(err) return err default: return nil @@ -516,9 +412,53 @@ type Validator interface { Validate() error } -func validate[Req any](req Req) error { - if v, ok := any(req).(Validator); ok { - return v.Validate() +func validate[Req any](ctx context.Context, req Req) error { + _, span := otel.Tracer("endpoint").Start(ctx, "validate") + defer span.End() + + v, ok := any(req).(Validator) + if !ok { + return nil + } + + err := v.Validate() + span.RecordError(err) + return err +} + +func (op *Operation[Req, Resp]) writeResponse(ctx context.Context, w http.ResponseWriter, resp Resp) error { + _, span := otel.Tracer("endpoint").Start(ctx, "Operation.writeResponse") + defer span.End() + + switch x := any(resp).(type) { + case io.WriterTo: + _, err := x.WriteTo(w) + span.RecordError(err) + return err + case encoding.BinaryMarshaler: + b, err := x.MarshalBinary() + if err != nil { + span.RecordError(err) + return err + } + + if ct, ok := any(resp).(ContentTyper); ok { + w.Header().Set("Content-Type", ct.ContentType()) + } + + w.WriteHeader(op.statusCode) + _, err = io.Copy(w, bytes.NewReader(b)) + span.RecordError(err) + return err + default: + w.WriteHeader(op.statusCode) + return nil } - return nil +} + +func (op *Operation[Req, Resp]) handleError(ctx context.Context, w http.ResponseWriter, err error) { + _, span := otel.Tracer("endpoint").Start(ctx, "Operation.handleError") + defer span.End() + + op.errHandler.HandleError(w, err) } diff --git a/rest/endpoint/endpoint_test.go b/rest/endpoint/endpoint_test.go index 3a63c96..63cbd01 100644 --- a/rest/endpoint/endpoint_test.go +++ b/rest/endpoint/endpoint_test.go @@ -18,6 +18,8 @@ import ( "github.com/stretchr/testify/assert" ) +type Empty struct{} + type noopHandler struct{} func (noopHandler) Handle(_ context.Context, _ Empty) (Empty, error) { @@ -40,18 +42,6 @@ func (x JsonContent) MarshalBinary() ([]byte, error) { return json.Marshal(x) } -type httpError struct { - status int -} - -func (httpError) Error() string { - return "" -} - -func (e httpError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(e.status) -} - type FailUnmarshalBinary struct{} var errUnmarshalBinary = errors.New("failed to unmarshal from binary") @@ -80,110 +70,13 @@ func (FailMarshalBinary) MarshalBinary() ([]byte, error) { return nil, errMarshalBinary } -func TestGet(t *testing.T) { - t.Run("will return 405 http status code", func(t *testing.T) { - t.Run("if a non GET method is used to call its returned endpoint", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - }) -} - -func TestPost(t *testing.T) { - t.Run("will return 405 http status code", func(t *testing.T) { - t.Run("if a non POST method is used to call its returned endpoint", func(t *testing.T) { - pattern := "/" - - e := Post( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - }) -} - -func TestPut(t *testing.T) { - t.Run("will return 405 http status code", func(t *testing.T) { - t.Run("if a non PUT method is used to call its returned endpoint", func(t *testing.T) { - pattern := "/" - - e := Put( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - }) -} - -func TestDelete(t *testing.T) { - t.Run("will return 405 http status code", func(t *testing.T) { - t.Run("if a non DELETE method is used to call its returned endpoint", func(t *testing.T) { - pattern := "/" - - e := Delete( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - }) -} - func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return the default success http status code", func(t *testing.T) { t.Run("if the underlying Handler succeeds with an empty response", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, - noopHandler{}, - ) + e := NewOperation(noopHandler{}) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -194,17 +87,14 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the underlying Handler succeeds with a encoding.BinaryMarshaler response", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := NewOperation( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{Value: "hello, world"}, nil }), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -231,14 +121,14 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will inject path params", func(t *testing.T) { t.Run("if a valid http.ServeMux path param pattern is used", func(t *testing.T) { - pattern := "/{id}" - - e := Get( - pattern, + e := NewOperation( HandlerFunc[Empty, JsonContent](func(ctx context.Context, _ Empty) (JsonContent, error) { v := PathValue(ctx, "id") return JsonContent{Value: v}, nil }), + PathParams(PathParam{ + Name: "id", + }), ) w := httptest.NewRecorder() @@ -247,7 +137,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { // for path params a http.ServeMux must be used since // Endpoint doesn't support it directly mux := http.NewServeMux() - mux.Handle(pattern, e) + mux.Handle("GET /{id}", e) mux.ServeHTTP(w, r) resp := w.Result() @@ -273,10 +163,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will inject headers", func(t *testing.T) { t.Run("if a header is configured with the Headers option", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := NewOperation( HandlerFunc[Empty, JsonContent](func(ctx context.Context, _ Empty) (JsonContent, error) { v := HeaderValue(ctx, "test-header") return JsonContent{Value: v}, nil @@ -287,7 +174,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("test-header", "hello, world") e.ServeHTTP(w, r) @@ -315,10 +202,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will inject query params", func(t *testing.T) { t.Run("if a query param is configured with the QueryParams option", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := NewOperation( HandlerFunc[Empty, JsonContent](func(ctx context.Context, _ Empty) (JsonContent, error) { v := QueryValue(ctx, "test-query") return JsonContent{Value: v}, nil @@ -329,7 +213,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern+"?test-query=abc123", nil) + r := httptest.NewRequest(http.MethodGet, "/?test-query=abc123", nil) e.ServeHTTP(w, r) @@ -356,20 +240,18 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return custom success http status code", func(t *testing.T) { t.Run("if the StatusCode option is used and the underlying Handler succeeds with an empty response", func(t *testing.T) { - pattern := "/" statusCode := http.StatusCreated if !assert.NotEqual(t, DefaultStatusCode, statusCode) { return } - e := Get( - pattern, + e := NewOperation( noopHandler{}, StatusCode(statusCode), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -380,14 +262,12 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the StatusCode option is used and the underlying Handler succeeds with a encoding.BinaryMarshaler response", func(t *testing.T) { - pattern := "/" statusCode := http.StatusCreated if !assert.NotEqual(t, DefaultStatusCode, statusCode) { return } - e := Get( - pattern, + e := NewOperation( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{Value: "hello, world"}, nil }), @@ -395,7 +275,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -422,17 +302,14 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return non-success http status code", func(t *testing.T) { t.Run("if the underlying Handler returns an error", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, errors.New("failed") }), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -443,14 +320,12 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if a custom error handler is set", func(t *testing.T) { - pattern := "/" errStatusCode := http.StatusServiceUnavailable if !assert.NotEqual(t, DefaultErrorStatusCode, errStatusCode) { return } - e := Get( - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, errors.New("failed") }), @@ -460,32 +335,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, errStatusCode, resp.StatusCode) { - return - } - }) - - t.Run("if the underlying error implements http.Handler", func(t *testing.T) { - pattern := "/" - errStatusCode := http.StatusServiceUnavailable - if !assert.NotEqual(t, DefaultErrorStatusCode, errStatusCode) { - return - } - - e := Get( - pattern, - HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { - return Empty{}, httpError{status: errStatusCode} - }), - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -495,30 +345,9 @@ func TestEndpoint_ServeHTTP(t *testing.T) { } }) - t.Run("if the http request is for the wrong http method", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, - noopHandler{}, - ) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, nil) - - e.ServeHTTP(w, r) - - resp := w.Result() - if !assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) { - return - } - }) - t.Run("if a required http header is missing", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + var caughtError error + e := NewOperation( noopHandler{}, Headers( Header{ @@ -526,24 +355,35 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Required: true, }, ), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var herr MissingRequiredHeaderError + if !assert.ErrorAs(t, caughtError, &herr) { + return + } + if !assert.NotEmpty(t, herr.Error()) { return } }) t.Run("if a http header does not match its expected pattern", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + var caughtError error + e := NewOperation( noopHandler{}, Headers( Header{ @@ -551,25 +391,36 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Pattern: "^[a-zA-Z]*$", }, ), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("Authorization", "abc123") e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var herr InvalidHeaderError + if !assert.ErrorAs(t, caughtError, &herr) { + return + } + if !assert.NotEmpty(t, herr.Error()) { return } }) t.Run("if a required query param is missing", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + var caughtError error + e := NewOperation( noopHandler{}, QueryParams( QueryParam{ @@ -577,24 +428,35 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Required: true, }, ), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var qerr MissingRequiredQueryParamError + if !assert.ErrorAs(t, caughtError, &qerr) { + return + } + if !assert.NotEmpty(t, qerr.Error()) { return } }) t.Run("if a query param does not match its expected pattern", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + var caughtError error + e := NewOperation( noopHandler{}, QueryParams( QueryParam{ @@ -602,47 +464,68 @@ func TestEndpoint_ServeHTTP(t *testing.T) { Pattern: "^[a-zA-Z]*$", }, ), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern+"?test=abc123", nil) + r := httptest.NewRequest(http.MethodGet, "/?test=abc123", nil) e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var qerr InvalidQueryParamError + if !assert.ErrorAs(t, caughtError, &qerr) { + return + } + if !assert.NotEmpty(t, qerr.Error()) { return } }) t.Run("if the request content type header does not match the content type from ContentTyper", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + var caughtError error + e := NewOperation( HandlerFunc[JsonContent, Empty](func(_ context.Context, _ JsonContent) (Empty, error) { return Empty{}, nil }), + OnError(errorHandlerFunc(func(w http.ResponseWriter, err error) { + caughtError = err + + w.WriteHeader(DefaultErrorStatusCode) + })), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Content-Type", "application/xml") e.ServeHTTP(w, r) resp := w.Result() - if !assert.Equal(t, http.StatusBadRequest, resp.StatusCode) { + if !assert.Equal(t, DefaultErrorStatusCode, resp.StatusCode) { + return + } + + var herr InvalidHeaderError + if !assert.ErrorAs(t, caughtError, &herr) { + return + } + if !assert.NotEmpty(t, herr.Error()) { return } }) t.Run("if the request body fails to unmarshal", func(t *testing.T) { - pattern := "/" - var caughtError error - e := Post( - pattern, + e := NewOperation( HandlerFunc[FailUnmarshalBinary, Empty](func(_ context.Context, _ FailUnmarshalBinary) (Empty, error) { return Empty{}, nil }), @@ -654,7 +537,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, strings.NewReader(`{}`)) + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) e.ServeHTTP(w, r) @@ -668,11 +551,8 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the unmarshaled request body is invalid", func(t *testing.T) { - pattern := "/" - var caughtError error - e := Post( - pattern, + e := NewOperation( HandlerFunc[InvalidRequest, Empty](func(_ context.Context, _ InvalidRequest) (Empty, error) { return Empty{}, nil }), @@ -684,7 +564,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, strings.NewReader(`{}`)) + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) e.ServeHTTP(w, r) @@ -698,11 +578,8 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the response body fails to marshal itself to binary", func(t *testing.T) { - pattern := "/" - var caughtError error - e := Post( - pattern, + e := NewOperation( HandlerFunc[Empty, FailMarshalBinary](func(_ context.Context, _ Empty) (FailMarshalBinary, error) { return FailMarshalBinary{}, nil }), @@ -714,7 +591,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, pattern, strings.NewReader(`{}`)) + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) e.ServeHTTP(w, r) @@ -730,17 +607,14 @@ func TestEndpoint_ServeHTTP(t *testing.T) { t.Run("will return response header", func(t *testing.T) { t.Run("if the response body implements ContentTyper", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := NewOperation( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{Value: "hello, world"}, nil }), ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) @@ -768,10 +642,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { }) t.Run("if the underlying Handler sets a custom response header using the context", func(t *testing.T) { - pattern := "/" - - e := Get( - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(ctx context.Context, _ Empty) (Empty, error) { SetResponseHeader(ctx, "Content-Type", "test-content-type") return Empty{}, nil @@ -779,7 +650,7 @@ func TestEndpoint_ServeHTTP(t *testing.T) { ) w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, pattern, nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) e.ServeHTTP(w, r) diff --git a/rest/endpoint/inject.go b/rest/endpoint/inject.go index 0179345..c531e3c 100644 --- a/rest/endpoint/inject.go +++ b/rest/endpoint/inject.go @@ -9,11 +9,16 @@ import ( "context" "net/http" "net/url" + + "go.opentelemetry.io/otel" ) type injector func(context.Context, http.ResponseWriter, *http.Request) context.Context func inject(ctx context.Context, w http.ResponseWriter, r *http.Request, injectors ...injector) context.Context { + _, span := otel.Tracer("endpoint").Start(ctx, "inject") + defer span.End() + for _, injector := range injectors { ctx = injector(ctx, w, r) } diff --git a/rest/endpoint/openapi.go b/rest/endpoint/openapi.go deleted file mode 100644 index 299cecf..0000000 --- a/rest/endpoint/openapi.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2024 Z5Labs and Contributors -// -// This software is released under the MIT License. -// https://opensource.org/licenses/MIT - -package endpoint - -import ( - "github.com/swaggest/openapi-go/openapi3" -) - -func setOpenApiSpec(o *options) func(*openapi3.Spec) { - return compose( - addSchemas(o.schemas), - addParameters(o.method, o.pattern, o.pathParams...), - addParameters(o.method, o.pattern, o.headers...), - addParameters(o.method, o.pattern, o.queryParams...), - addRequestBody(o.method, o.pattern, o.request), - addResponses(o.method, o.pattern, o.responses), - ) -} - -func compose(fs ...func(*openapi3.Spec)) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - for _, f := range fs { - f(s) - } - } -} - -func addSchemas(schemas map[string]*openapi3.Schema) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - if len(schemas) == 0 { - return - } - - if s.Components == nil { - s.Components = new(openapi3.Components) - } - - comps := s.Components - if comps.Schemas == nil { - comps.Schemas = new(openapi3.ComponentsSchemas) - } - - compSchemas := comps.Schemas - if compSchemas.MapOfSchemaOrRefValues == nil { - compSchemas.MapOfSchemaOrRefValues = make(map[string]openapi3.SchemaOrRef, len(schemas)) - } - - for name, schema := range schemas { - compSchemas.MapOfSchemaOrRefValues[name] = openapi3.SchemaOrRef{ - Schema: schema, - } - } - } -} - -func addParameters(method, pattern string, params ...*openapi3.Parameter) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - if len(params) == 0 { - return - } - - if s.Paths.MapOfPathItemValues == nil { - s.Paths.MapOfPathItemValues = make(map[string]openapi3.PathItem) - } - - pathItemVals := s.Paths.MapOfPathItemValues - pathItem, ok := pathItemVals[pattern] - if !ok { - pathItem = openapi3.PathItem{ - MapOfOperationValues: make(map[string]openapi3.Operation), - } - } - - opVals := pathItem.MapOfOperationValues - opVal, ok := opVals[method] - if !ok { - opVal = openapi3.Operation{} - } - - for _, param := range params { - opVal.Parameters = append(opVal.Parameters, openapi3.ParameterOrRef{ - Parameter: param, - }) - } - - opVals[method] = opVal - pathItemVals[pattern] = pathItem - } -} - -func addRequestBody(method, pattern string, reqBody *openapi3.RequestBody) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - if reqBody == nil { - return - } - - if s.Paths.MapOfPathItemValues == nil { - s.Paths.MapOfPathItemValues = make(map[string]openapi3.PathItem) - } - - pathItemVals := s.Paths.MapOfPathItemValues - pathItem, ok := pathItemVals[pattern] - if !ok { - pathItem = openapi3.PathItem{ - MapOfOperationValues: make(map[string]openapi3.Operation), - } - } - - opVals := pathItem.MapOfOperationValues - opVal, ok := opVals[method] - if !ok { - opVal = openapi3.Operation{} - } - - opVal.RequestBody = &openapi3.RequestBodyOrRef{ - RequestBody: reqBody, - } - - opVals[method] = opVal - pathItemVals[pattern] = pathItem - } -} - -func addResponses(method, pattern string, responses *openapi3.Responses) func(*openapi3.Spec) { - return func(s *openapi3.Spec) { - if responses == nil { - return - } - - if s.Paths.MapOfPathItemValues == nil { - s.Paths.MapOfPathItemValues = make(map[string]openapi3.PathItem) - } - - pathItemVals := s.Paths.MapOfPathItemValues - pathItem, ok := pathItemVals[pattern] - if !ok { - pathItem = openapi3.PathItem{ - MapOfOperationValues: make(map[string]openapi3.Operation), - } - } - - opVals := pathItem.MapOfOperationValues - opVal, ok := opVals[method] - if !ok { - opVal = openapi3.Operation{} - } - - opVal.Responses = *responses - - opVals[method] = opVal - pathItemVals[pattern] = pathItem - } -} diff --git a/rest/endpoint/openapi_test.go b/rest/endpoint/openapi_test.go index cb2b7da..594c0f4 100644 --- a/rest/endpoint/openapi_test.go +++ b/rest/endpoint/openapi_test.go @@ -9,9 +9,7 @@ import ( "context" "encoding/json" "net/http" - "path" "strconv" - "strings" "testing" "github.com/z5labs/bedrock/pkg/ptr" @@ -23,50 +21,27 @@ import ( func TestEndpoint_OpenApi(t *testing.T) { t.Run("will required path parameter", func(t *testing.T) { t.Run("if a http.ServeMux path parameter pattern is used", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/{id}" - - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), + PathParams(PathParam{ + Name: "id", + Required: true, + }), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -90,55 +65,29 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set non-required header parameter", func(t *testing.T) { t.Run("if a header is provided with the Headers option", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" header := Header{ Name: "MyHeader", Required: true, } - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), Headers(header), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -162,55 +111,29 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set required header parameter", func(t *testing.T) { t.Run("if a header is provided with the Headers option", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" header := Header{ Name: "MyHeader", Required: true, } - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), Headers(header), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -234,54 +157,28 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set non-required query param", func(t *testing.T) { t.Run("if a query param is provided with the QueryParams option", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" queryParam := QueryParam{ Name: "myparam", } - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), QueryParams(queryParam), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -305,55 +202,29 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set required query param", func(t *testing.T) { t.Run("if a query param is provided with the QueryParams option", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" queryParam := QueryParam{ Name: "myparam", Required: true, } - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), QueryParams(queryParam), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - op := ops[method] params := op.Parameters if !assert.Len(t, params, 1) { return @@ -377,50 +248,24 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set request body type", func(t *testing.T) { t.Run("if the request type implements ContentTyper interface", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" - - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[JsonContent, Empty](func(_ context.Context, _ JsonContent) (Empty, error) { return Empty{}, nil }), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - reqBodyOrRef := ops[method].RequestBody + reqBodyOrRef := op.RequestBody if !assert.NotNil(t, reqBodyOrRef) { return } @@ -443,31 +288,7 @@ func TestEndpoint_OpenApi(t *testing.T) { return } - schemaRef := schemaOrRef.SchemaReference - if !assert.NotNil(t, schemaRef) { - return - } - _, schemaRefName := path.Split(schemaRef.Ref) - - comps := spec.Components - if !assert.NotNil(t, comps) { - return - } - - schemas := comps.Schemas - if !assert.NotNil(t, schemas) { - return - } - - schemaOrRefValues := schemas.MapOfSchemaOrRefValues - if !assert.Len(t, schemaOrRefValues, 1) { - return - } - if !assert.Contains(t, schemaOrRefValues, schemaRefName) { - return - } - - schema := schemaOrRefValues[schemaRefName].Schema + schema := schemaOrRef.Schema if !assert.NotNil(t, schema) { return } @@ -484,50 +305,24 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set response body type", func(t *testing.T) { t.Run("if the response type implements ContentTyper interface", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" - - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[Empty, JsonContent](func(_ context.Context, _ Empty) (JsonContent, error) { return JsonContent{}, nil }), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - respOrRefValues := ops[method].Responses.MapOfResponseOrRefValues + respOrRefValues := op.Responses.MapOfResponseOrRefValues if !assert.Len(t, respOrRefValues, 1) { return } @@ -553,31 +348,7 @@ func TestEndpoint_OpenApi(t *testing.T) { return } - schemaRef := schemaOrRef.SchemaReference - if !assert.NotNil(t, schemaRef) { - return - } - _, respRefName := path.Split(schemaRef.Ref) - - comps := spec.Components - if !assert.NotNil(t, comps) { - return - } - - schemas := comps.Schemas - if !assert.NotNil(t, schemas) { - return - } - - schemaOrRefValues := schemas.MapOfSchemaOrRefValues - if !assert.Len(t, schemaOrRefValues, 1) { - return - } - if !assert.Contains(t, schemaOrRefValues, respRefName) { - return - } - - schema := schemaOrRefValues[respRefName].Schema + schema := schemaOrRef.Schema if !assert.NotNil(t, schema) { return } @@ -594,50 +365,24 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will set a empty response body", func(t *testing.T) { t.Run("if the response type does not implement ContentTyper", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" - - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - respOrRefValues := ops[method].Responses.MapOfResponseOrRefValues + respOrRefValues := op.Responses.MapOfResponseOrRefValues if !assert.Len(t, respOrRefValues, 1) { return } @@ -657,52 +402,27 @@ func TestEndpoint_OpenApi(t *testing.T) { }) t.Run("if the Returns option is used with a http status code", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" statusCode := http.StatusBadRequest - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), Returns(statusCode), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - respOrRefValues := ops[method].Responses.MapOfResponseOrRefValues + respOrRefValues := op.Responses.MapOfResponseOrRefValues if !assert.Len(t, respOrRefValues, 2) { return } @@ -733,56 +453,30 @@ func TestEndpoint_OpenApi(t *testing.T) { t.Run("will override default response status code", func(t *testing.T) { t.Run("if DefaultStatusCode option is used", func(t *testing.T) { - method := strings.ToLower(http.MethodPost) - pattern := "/" - statusCode := http.StatusCreated if !assert.NotEqual(t, statusCode, DefaultStatusCode) { return } - e := New( - method, - pattern, + e := NewOperation( HandlerFunc[Empty, Empty](func(_ context.Context, _ Empty) (Empty, error) { return Empty{}, nil }), StatusCode(statusCode), ) - refSpec := &openapi3.Spec{ - Openapi: "3.0.3", - } - e.OpenApi(refSpec) - - b, err := json.Marshal(refSpec) + b, err := json.Marshal(e.OpenApi()) if !assert.Nil(t, err) { return } - var spec openapi3.Spec - err = json.Unmarshal(b, &spec) + var op openapi3.Operation + err = json.Unmarshal(b, &op) if !assert.Nil(t, err) { return } - pathItems := spec.Paths.MapOfPathItemValues - if !assert.Len(t, pathItems, 1) { - return - } - if !assert.Contains(t, pathItems, pattern) { - return - } - - ops := pathItems[pattern].MapOfOperationValues - if !assert.Len(t, ops, 1) { - return - } - if !assert.Contains(t, ops, method) { - return - } - - respOrRefValues := ops[method].Responses.MapOfResponseOrRefValues + respOrRefValues := op.Responses.MapOfResponseOrRefValues if !assert.Len(t, respOrRefValues, 1) { return } diff --git a/rest/endpoint/validate.go b/rest/endpoint/validate.go index dac808d..1370ed3 100644 --- a/rest/endpoint/validate.go +++ b/rest/endpoint/validate.go @@ -6,46 +6,28 @@ package endpoint import ( + "context" "fmt" "net/http" "regexp" + + "go.opentelemetry.io/otel" ) -func validateRequest(r *http.Request, validators ...func(*http.Request) error) error { +func validateRequest(ctx context.Context, r *http.Request, validators ...func(*http.Request) error) error { + _, span := otel.Tracer("endpoint").Start(ctx, "validateRequest") + defer span.End() + for _, validator := range validators { err := validator(r) if err != nil { + span.RecordError(err) return err } } return nil } -// InvalidMethodError represents when a request was sent to an endpoint -// for the incorrect method. -type InvalidMethodError struct { - Method string -} - -// Error implements [error] interface. -func (e InvalidMethodError) Error() string { - return fmt.Sprintf("received invalid method for endpoint: %s", e.Method) -} - -// ServeHTTP implements the [http.Handler] interface. -func (InvalidMethodError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusMethodNotAllowed) -} - -func validateMethod(method string) func(*http.Request) error { - return func(r *http.Request) error { - if r.Method == method { - return nil - } - return InvalidMethodError{Method: r.Method} - } -} - // InvalidHeaderError type InvalidHeaderError struct { Header string @@ -56,11 +38,6 @@ func (e InvalidHeaderError) Error() string { return fmt.Sprintf("received invalid header for endpoint: %s", e.Header) } -// ServeHTTP implements the [http.Handler] interface. -func (InvalidHeaderError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) -} - // MissingRequiredHeaderError type MissingRequiredHeaderError struct { Header string @@ -71,11 +48,6 @@ func (e MissingRequiredHeaderError) Error() string { return fmt.Sprintf("missing required header for endpoint: %s", e.Header) } -// ServeHTTP implements the [http.Handler] interface. -func (MissingRequiredHeaderError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) -} - func validateHeader(h Header) func(*http.Request) error { var pattern *regexp.Regexp if h.Pattern != "" { @@ -107,11 +79,6 @@ func (e InvalidQueryParamError) Error() string { return fmt.Sprintf("received invalid query param for endpoint: %s", e.Param) } -// ServeHTTP implements the [http.Handler] interface. -func (InvalidQueryParamError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) -} - // MissingRequiredQueryParamError type MissingRequiredQueryParamError struct { Param string @@ -122,11 +89,6 @@ func (e MissingRequiredQueryParamError) Error() string { return fmt.Sprintf("missing required query param for endpoint: %s", e.Param) } -// ServeHTTP implements the [http.Handler] interface. -func (MissingRequiredQueryParamError) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) -} - func validateQueryParam(qp QueryParam) func(*http.Request) error { var pattern *regexp.Regexp if qp.Pattern != "" { diff --git a/rest/rest.go b/rest/rest.go index 7841bf8..77e3886 100644 --- a/rest/rest.go +++ b/rest/rest.go @@ -33,27 +33,50 @@ func ListenOn(port uint) Option { } } -// Handler represents anything that can handle HTTP requests +// Operation represents anything that can handle HTTP requests // and provide OpenAPI documentation for itself. -type Handler interface { +type Operation interface { http.Handler - OpenApi(*openapi3.Spec) + OpenApi() openapi3.Operation } -// Handler registers the provider [Handler] with both +// Endpoint registers the [Operation] with both // the App wide OpenAPI spec and the App wide HTTP server. -func Handle(pattern string, h Handler) Option { +func Endpoint(method, pattern string, op Operation) Option { return func(app *App) { - h.OpenApi(app.spec) + err := app.spec.AddOperation(method, pattern, op.OpenApi()) + if err != nil { + panic(err) + } app.mux.Handle( pattern, - otelhttp.WithRouteTag(pattern, h), + otelhttp.WithRouteTag(pattern, op), ) } } +// Title sets the title of the API in its OpenAPI spec. +// +// In order for your OpenAPI spec to be fully compliant +// with other tooling, this option is required. +func Title(s string) Option { + return func(a *App) { + a.spec.Info.Title = s + } +} + +// Version sets the API version in its OpenAPI spec. +// +// In order for your OpenAPI spec to be fully compliant +// with other tooling, this option is required. +func Version(s string) Option { + return func(a *App) { + a.spec.Info.Version = s + } +} + // App is a [bedrock.App] implementation to help simplify // building RESTful applications. type App struct { diff --git a/rest/rest_example_test.go b/rest/rest_example_test.go index 0916a61..5685880 100644 --- a/rest/rest_example_test.go +++ b/rest/rest_example_test.go @@ -73,10 +73,12 @@ func Example() { app := NewApp( listenOnRandomPort(addrCh), - Handle( + Title("Example"), + Version("v0.0.0"), + Endpoint( + http.MethodPost, "/", - endpoint.Post( - "/", + endpoint.NewOperation( echoService{}, ), ),