diff --git a/.github/workflows/lint-test-cover.yml b/.github/workflows/lint-test-cover.yml index 6850943..9bee599 100644 --- a/.github/workflows/lint-test-cover.yml +++ b/.github/workflows/lint-test-cover.yml @@ -28,14 +28,14 @@ jobs: - name: Install Golang uses: actions/setup-go@v3 with: - go-version: 1.18 + go-version: 1.20.4 check-latest: true - name: Print Go Version run: go version - name: Cache go modules - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | ~/.cache/go-build @@ -46,7 +46,7 @@ jobs: - name: Test and produce coverage profile id: coverage - run: echo "::set-output name=total::$(make test-coverage)" + run: echo "total=$(make test-coverage)" >> $GITHUB_OUTPUT - name: Update README.md with coverage run: | diff --git a/.gitignore b/.gitignore index 6e61439..5a61a83 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +.DS_Store + # Jet Brains .idea diff --git a/Makefile b/Makefile index f4b7913..90f55e3 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,6 @@ COVERAGE_PROFILE_FILE := profile.cov COVERAGE_HTML_FILE := coverage.html SHELL := bash -eo pipefail -c - $(COVERAGE_PROFILE_FILE): $(shell find router examples) @go test -v -race -failfast \ -coverprofile=$@ \ @@ -17,13 +16,18 @@ test-coverage: $(COVERAGE_PROFILE_FILE) $(COVERAGE_HTML_FILE): $(COVERAGE_PROFILE_FILE) @go tool cover -html=$(COVERAGE_PROFILE_FILE) -o $(COVERAGE_HTML_FILE) +FILE ?= file0 + .PHONY: show-coverage show-coverage: $(COVERAGE_HTML_FILE) + $(eval ANCHOR:=$(shell cat $(COVERAGE_HTML_FILE) | grep -E '' | grep $(FILE) | sed -n 's/^.*value="\(.*\)".*$$/\1/p')) + @sed -E -i '' 's/select\("file[0-9]+"\);/select("$(ANCHOR)");/g' $(COVERAGE_HTML_FILE) @open $(COVERAGE_HTML_FILE) + .PHONY: test test: - @go test -race -failfast ./... + @go test -race ./... .PHONY: clean-coverage clean-coverage: diff --git a/README.md b/README.md index 49f6872..a53b421 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # Cellotape - Beta - OpenAPI Router for Go -![](https://badgen.net/badge/coverage/25/green?icon=github) +![99.2%](https://badgen.net/badge/coverage/99.2%25/green?icon=github) Cellotape requires Go 1.18 or above. diff --git a/examples/hello_world_example/main_test.go b/examples/hello_world_example/main_test.go index 4917d41..b9cfb66 100644 --- a/examples/hello_world_example/main_test.go +++ b/examples/hello_world_example/main_test.go @@ -23,7 +23,6 @@ func TestHelloWorldExample(t *testing.T) { handler, err := router.NewOpenAPIRouter(spec). WithOperation("greet", api.GreetOperationHandler). AsHandler() - fmt.Println(err) require.NoError(t, err) ts := httptest.NewServer(handler) diff --git a/examples/todo_list_app_example/main.go b/examples/todo_list_app_example/main.go index cc9d8dc..44caf34 100644 --- a/examples/todo_list_app_example/main.go +++ b/examples/todo_list_app_example/main.go @@ -36,7 +36,7 @@ func mainHandler() error { return err } port := 8080 - fmt.Printf("Starting HTTP server on port %d\n", port) + log.Printf("Starting HTTP server on port %d\n", port) if err = http.ListenAndServe(fmt.Sprintf(":%d", port), handler); err != nil { return err } diff --git a/examples/todo_list_app_example/main_test.go b/examples/todo_list_app_example/main_test.go index 329beae..df58e43 100644 --- a/examples/todo_list_app_example/main_test.go +++ b/examples/todo_list_app_example/main_test.go @@ -34,7 +34,7 @@ func TestGetAllTasks(t *testing.T) { assert.JSONEq(t, `{ "results": [], "page": 0, - "pageSize": 0, + "pageSize": 10, "isLast": true }`, string(response)) } @@ -52,7 +52,7 @@ func TestCreateNewTaskAndGetIt(t *testing.T) { req, err := http.NewRequest("POST", fmt.Sprintf("%s/tasks", ts.URL), request) require.NoError(t, err) req.Header.Set("Authorization", "Bearer secret") - //req.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Type", "application/json") client := http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -77,6 +77,65 @@ func TestCreateNewTaskAndGetIt(t *testing.T) { assert.JSONEq(t, taskJson, string(data)) } +func TestRequestQueryParamViolateSchemaValidations(t *testing.T) { + ts := initAPI(t) + defer ts.Close() + req, err := http.NewRequest("GET", fmt.Sprintf("%s/tasks?pageSize=30", ts.URL), nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer secret") + client := http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + response, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, + `invalid request query param. parameter "pageSize" in query has an error: number must be at most 20`, + string(response)) +} + +func TestRequestPathParamViolateSchemaValidations(t *testing.T) { + ts := initAPI(t) + defer ts.Close() + req, err := http.NewRequest("GET", fmt.Sprintf("%s/tasks/123", ts.URL), nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer secret") + client := http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + response, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, + `invalid request path param. parameter "id" in path has an error: minimum string length is 36`, + string(response)) +} + +func TestRequestBodyViolateSchemaValidations(t *testing.T) { + ts := initAPI(t) + defer ts.Close() + taskJson := `{ + "summary": "code first approach", + "description": "add support for code first approach", + "status": "archived" + }` + request := bytes.NewBufferString(taskJson) + + req, err := http.NewRequest("POST", fmt.Sprintf("%s/tasks", ts.URL), request) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer secret") + req.Header.Set("Content-Type", "application/json") + client := http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + response, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, + `invalid request body. request body has an error: doesn't match schema #/components/schemas/Task: value "archived" is not one of the allowed values`, + string(response)) +} + func initAPI(t *testing.T) *httptest.Server { spec, err := router.NewSpecFromData(specData) require.NoError(t, err) diff --git a/examples/todo_list_app_example/middlewares/auth.go b/examples/todo_list_app_example/middlewares/auth.go index 89295f6..a12c34d 100644 --- a/examples/todo_list_app_example/middlewares/auth.go +++ b/examples/todo_list_app_example/middlewares/auth.go @@ -6,13 +6,14 @@ import ( "github.com/piiano/cellotape/examples/todo_list_app_example/models" r "github.com/piiano/cellotape/router" + "github.com/piiano/cellotape/router/utils" ) const token = "secret" var authHeader = fmt.Sprintf("Bearer %s", token) -var AuthMiddleware = r.NewHandler(func(c *r.Context, req r.Request[r.Nil, r.Nil, r.Nil]) (r.Response[authResponses], error) { +var AuthMiddleware = r.NewHandler(func(c *r.Context, req r.Request[utils.Nil, utils.Nil, utils.Nil]) (r.Response[authResponses], error) { if req.Headers.Get("Authorization") != authHeader { return r.SendJSON(authResponses{Unauthorized: models.HttpError{ Error: "Unauthorized", diff --git a/examples/todo_list_app_example/middlewares/logger.go b/examples/todo_list_app_example/middlewares/logger.go index 801c337..6440313 100644 --- a/examples/todo_list_app_example/middlewares/logger.go +++ b/examples/todo_list_app_example/middlewares/logger.go @@ -7,16 +7,16 @@ import ( r "github.com/piiano/cellotape/router" ) -var LoggerMiddleware = r.NewHandler(loggerHandler) +var LoggerMiddleware = r.RawHandler(loggerHandler) -func loggerHandler(c *r.Context, request r.Request[r.Nil, r.Nil, r.Nil]) (r.Response[any], error) { +func loggerHandler(c *r.Context) error { start := time.Now() response, err := c.Next() duration := time.Since(start) if err != nil { log.Printf("[ERROR] error occurred: %s. - %s - [%s] %s\n", err.Error(), duration, c.Request.Method, c.Request.URL.Path) - return r.Response[any]{}, nil + return err } log.Printf("[INFO] (status %d | %d bytes | %s) - [%s] %s\n", response.Status, len(response.Body), duration, c.Request.Method, c.Request.URL.Path) - return r.Response[any]{}, nil + return nil } diff --git a/examples/todo_list_app_example/middlewares/powered_by.go b/examples/todo_list_app_example/middlewares/powered_by.go index 88909ae..5ac7cf9 100644 --- a/examples/todo_list_app_example/middlewares/powered_by.go +++ b/examples/todo_list_app_example/middlewares/powered_by.go @@ -2,11 +2,12 @@ package middlewares import ( r "github.com/piiano/cellotape/router" + "github.com/piiano/cellotape/router/utils" ) var PoweredByMiddleware = r.NewHandler(poweredByHandler) -func poweredByHandler(c *r.Context, _ r.Request[r.Nil, r.Nil, r.Nil]) (r.Response[any], error) { +func poweredByHandler(c *r.Context, _ r.Request[utils.Nil, utils.Nil, utils.Nil]) (r.Response[any], error) { c.Writer.Header().Add("X-Powered-By", "Piiano OpenAPI Router") _, err := c.Next() return r.Response[any]{}, err diff --git a/examples/todo_list_app_example/openapi.yaml b/examples/todo_list_app_example/openapi.yaml index 625376c..dc3c0c6 100644 --- a/examples/todo_list_app_example/openapi.yaml +++ b/examples/todo_list_app_example/openapi.yaml @@ -120,6 +120,8 @@ components: Id: type: string format: uuid + minLength: 36 + maxLength: 36 nullable: false Identifiable: type: object diff --git a/examples/todo_list_app_example/rest/create_new_task_handler.go b/examples/todo_list_app_example/rest/create_new_task_handler.go index c76b32c..9b7f76f 100644 --- a/examples/todo_list_app_example/rest/create_new_task_handler.go +++ b/examples/todo_list_app_example/rest/create_new_task_handler.go @@ -4,10 +4,11 @@ import ( m "github.com/piiano/cellotape/examples/todo_list_app_example/models" "github.com/piiano/cellotape/examples/todo_list_app_example/services" r "github.com/piiano/cellotape/router" + "github.com/piiano/cellotape/router/utils" ) func createNewTaskOperation(tasks services.TasksService) r.Handler { - return r.NewHandler(func(c *r.Context, request r.Request[m.Task, r.Nil, r.Nil]) (r.Response[createNewTaskResponses], error) { + return r.NewHandler(func(c *r.Context, request r.Request[m.Task, utils.Nil, utils.Nil]) (r.Response[createNewTaskResponses], error) { id := tasks.CreateTask(request.Body) return r.SendOKJSON(createNewTaskResponses{OK: m.Identifiable{ID: id}}), nil }) diff --git a/examples/todo_list_app_example/rest/delete_task_by_id_handler.go b/examples/todo_list_app_example/rest/delete_task_by_id_handler.go index e953717..64fb759 100644 --- a/examples/todo_list_app_example/rest/delete_task_by_id_handler.go +++ b/examples/todo_list_app_example/rest/delete_task_by_id_handler.go @@ -9,10 +9,11 @@ import ( m "github.com/piiano/cellotape/examples/todo_list_app_example/models" "github.com/piiano/cellotape/examples/todo_list_app_example/services" r "github.com/piiano/cellotape/router" + "github.com/piiano/cellotape/router/utils" ) func deleteTaskByIDOperation(tasks services.TasksService) r.Handler { - return r.NewHandler(func(_ *r.Context, request r.Request[r.Nil, idPathParam, r.Nil]) (r.Response[deleteTaskByIDResponses], error) { + return r.NewHandler(func(_ *r.Context, request r.Request[utils.Nil, idPathParam, utils.Nil]) (r.Response[deleteTaskByIDResponses], error) { id, err := uuid.Parse(request.PathParams.ID) if err != nil { return r.SendJSON(deleteTaskByIDResponses{ @@ -35,7 +36,7 @@ func deleteTaskByIDOperation(tasks services.TasksService) r.Handler { } type deleteTaskByIDResponses struct { - NoContent r.Nil `status:"204"` + NoContent utils.Nil `status:"204"` BadRequest m.HttpError `status:"400"` Gone m.HttpError `status:"410"` } diff --git a/examples/todo_list_app_example/rest/get_tasks_by_id_handler.go b/examples/todo_list_app_example/rest/get_tasks_by_id_handler.go index a61eeab..190b764 100644 --- a/examples/todo_list_app_example/rest/get_tasks_by_id_handler.go +++ b/examples/todo_list_app_example/rest/get_tasks_by_id_handler.go @@ -9,10 +9,11 @@ import ( m "github.com/piiano/cellotape/examples/todo_list_app_example/models" "github.com/piiano/cellotape/examples/todo_list_app_example/services" r "github.com/piiano/cellotape/router" + "github.com/piiano/cellotape/router/utils" ) func getTaskByIDOperation(tasks services.TasksService) r.Handler { - return r.NewHandler(func(_ *r.Context, request r.Request[r.Nil, idPathParam, r.Nil]) (r.Response[getTaskByIDResponses], error) { + return r.NewHandler(func(_ *r.Context, request r.Request[utils.Nil, idPathParam, utils.Nil]) (r.Response[getTaskByIDResponses], error) { id, err := uuid.Parse(request.PathParams.ID) if err != nil { return r.SendJSON(getTaskByIDResponses{ diff --git a/examples/todo_list_app_example/rest/get_tasks_page_handler.go b/examples/todo_list_app_example/rest/get_tasks_page_handler.go index d11694a..48b2721 100644 --- a/examples/todo_list_app_example/rest/get_tasks_page_handler.go +++ b/examples/todo_list_app_example/rest/get_tasks_page_handler.go @@ -6,10 +6,11 @@ import ( m "github.com/piiano/cellotape/examples/todo_list_app_example/models" "github.com/piiano/cellotape/examples/todo_list_app_example/services" r "github.com/piiano/cellotape/router" + "github.com/piiano/cellotape/router/utils" ) func getTasksPageOperation(tasks services.TasksService) r.Handler { - return r.NewHandler(func(_ *r.Context, request r.Request[r.Nil, r.Nil, paginationQueryParams]) (r.Response[getTasksPageResponses], error) { + return r.NewHandler(func(_ *r.Context, request r.Request[utils.Nil, utils.Nil, paginationQueryParams]) (r.Response[getTasksPageResponses], error) { tasksPage := tasks.GetTasksPage(request.QueryParams.Page, request.QueryParams.PageSize) return r.SendOKJSON(getTasksPageResponses{OK: tasksPage}, http.Header{"Cache-Control": {"max-age=10"}}), nil }) diff --git a/examples/todo_list_app_example/rest/update_task_by_id_handler.go b/examples/todo_list_app_example/rest/update_task_by_id_handler.go index 6b9ba99..24dd40d 100644 --- a/examples/todo_list_app_example/rest/update_task_by_id_handler.go +++ b/examples/todo_list_app_example/rest/update_task_by_id_handler.go @@ -9,10 +9,11 @@ import ( m "github.com/piiano/cellotape/examples/todo_list_app_example/models" "github.com/piiano/cellotape/examples/todo_list_app_example/services" r "github.com/piiano/cellotape/router" + "github.com/piiano/cellotape/router/utils" ) func updateTaskByIDOperation(tasks services.TasksService) r.Handler { - return r.NewHandler(func(_ *r.Context, request r.Request[m.Task, idPathParam, r.Nil]) (r.Response[updateTaskByIDResponses], error) { + return r.NewHandler(func(_ *r.Context, request r.Request[m.Task, idPathParam, utils.Nil]) (r.Response[updateTaskByIDResponses], error) { id, err := uuid.Parse(request.PathParams.ID) if err != nil { return r.SendJSON(updateTaskByIDResponses{ @@ -35,7 +36,7 @@ func updateTaskByIDOperation(tasks services.TasksService) r.Handler { } type updateTaskByIDResponses struct { - NoContent r.Nil `status:"204"` + NoContent utils.Nil `status:"204"` BadRequest m.HttpError `status:"400"` NotFound m.HttpError `status:"404"` } diff --git a/go.mod b/go.mod index 4181c17..0dae9db 100644 --- a/go.mod +++ b/go.mod @@ -1,24 +1,20 @@ module github.com/piiano/cellotape -go 1.18 +go 1.20 -retract ( - v1.0.0 // Published accidentally. - v2.0.0 // Published accidentally. -) +retract v1.0.0 // Published accidentally. require ( - github.com/getkin/kin-openapi v0.94.0 + github.com/getkin/kin-openapi v0.112.0 github.com/gin-gonic/gin v1.7.7 github.com/google/uuid v1.3.0 github.com/invopop/jsonschema v0.4.0 github.com/julienschmidt/httprouter v1.3.0 - github.com/stretchr/testify v1.7.1 + github.com/stretchr/testify v1.8.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/ghodss/yaml v1.0.0 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/swag v0.19.5 // indirect github.com/go-playground/locales v0.14.0 // indirect @@ -27,12 +23,14 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-cmp v0.5.7 // indirect github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0 // indirect + github.com/invopop/yaml v0.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/ugorji/go/codec v1.2.7 // indirect golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect diff --git a/go.sum b/go.sum index 97f6cb9..16c3cb4 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/getkin/kin-openapi v0.94.0 h1:bAxg2vxgnHHHoeefVdmGbR+oxtJlcv5HsJJa3qmAHuo= -github.com/getkin/kin-openapi v0.94.0/go.mod h1:LWZfzOd7PRy8GJ1dJ6mCU6tNdSfOwRac1BUPam4aw6Q= -github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/getkin/kin-openapi v0.112.0 h1:lnLXx3bAG53EJVI4E/w0N8i1Y/vUZUEsnrXkgnfn7/Y= +github.com/getkin/kin-openapi v0.112.0/go.mod h1:QtwUNt0PAAgIIBEvFWYfB7dfngxtAaqCX1zYHMZDeK8= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= @@ -34,11 +32,14 @@ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0 h1:i462o439ZjprVSFSZLZxcsoAe592sZB1rci2Z8j4wdk= github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0/go.mod h1:N0Wam8K1arqPXNWjMo21EXnBPOPp36vB07FNRdD2geA= github.com/invopop/jsonschema v0.4.0 h1:Yuy/unfgCnfV5Wl7H0HgFufp/rlurqPOOuacqyByrws= github.com/invopop/jsonschema v0.4.0/go.mod h1:O9uiLokuu0+MGFlyiaqtWxwqJm41/+8Nj0lD7A36YH0= +github.com/invopop/yaml v0.1.0 h1:YW3WGUoJEXYfzWBjn00zIlrw7brGVD0fUKRYDPAPhrc= +github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -68,6 +69,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -75,14 +78,17 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= @@ -125,10 +131,10 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/options-schema.json b/options-schema.json index 2b212e0..eac2f41 100644 --- a/options-schema.json +++ b/options-schema.json @@ -6,39 +6,25 @@ "OperationValidationOptions": { "properties": { "validateRequestBody": { - "type": "string", - "enum": ["propagate-error", "print-warning", "ignore"], - "default": "propagate-error" + "type": "integer" }, "validatePathParams": { - "type": "string", - "enum": ["propagate-error", "print-warning", "ignore"], - "default": "propagate-error" + "type": "integer" }, "handleAllPathParams": { - "type": "string", - "enum": ["propagate-error", "print-warning", "ignore"], - "default": "propagate-error" + "type": "integer" }, "validateQueryParams": { - "type": "string", - "enum": ["propagate-error", "print-warning", "ignore"], - "default": "propagate-error" + "type": "integer" }, "handleAllQueryParams": { - "type": "string", - "enum": ["propagate-error", "print-warning", "ignore"], - "default": "propagate-error" + "type": "integer" }, "validateResponses": { - "type": "string", - "enum": ["propagate-error", "print-warning", "ignore"], - "default": "propagate-error" + "type": "integer" }, "handleAllOperationResponses": { - "type": "string", - "enum": ["propagate-error", "print-warning", "ignore"], - "default": "propagate-error" + "type": "integer" } }, "additionalProperties": false, @@ -46,16 +32,11 @@ }, "Options": { "properties": { - "$schema": { - "type": "string" - }, "recoverOnPanic": { "type": "boolean" }, "logLevel": { - "type": "string", - "enum": ["off", "error", "warn", "info"], - "default": "info" + "type": "integer" }, "operationValidations": { "patternProperties": { @@ -69,18 +50,23 @@ "$ref": "#/$defs/OperationValidationOptions" }, "mustHandleAllOperations": { - "type": "string", - "enum": ["propagate-error", "print-warning", "ignore"], - "default": "propagate-error" + "type": "integer" }, "handleAllContentTypes": { - "type": "string", - "enum": ["propagate-error", "print-warning", "ignore"], - "default": "propagate-error" + "type": "integer" + }, + "ExcludeOperations": { + "items": { + "type": "string" + }, + "type": "array" } }, "additionalProperties": false, - "type": "object" + "type": "object", + "required": [ + "ExcludeOperations" + ] } } } \ No newline at end of file diff --git a/options.json b/options.json index 0e0dcd2..c3bfecd 100644 --- a/options.json +++ b/options.json @@ -1,3 +1,3 @@ { - + "$schema": "./options-schema.json" } \ No newline at end of file diff --git a/router/binders.go b/router/binders.go index f483c11..5dfaee2 100644 --- a/router/binders.go +++ b/router/binders.go @@ -1,13 +1,18 @@ package router import ( + "bytes" "fmt" "io" "log" "mime" "net/http" + "net/url" "reflect" + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers" "github.com/gin-gonic/gin/binding" "github.com/julienschmidt/httprouter" @@ -16,8 +21,14 @@ import ( const contentTypeHeader = "Content-Type" +type binder[T any] func(*Context, *T) error + +func nilBinder[T any](*Context, *T) error { + return nil +} + // A request binder takes a Context with its untyped Context.Request and Context.Params and produce a typed Request. -type requestBinder[B, P, Q any] func(ctx *Context) (Request[B, P, Q], error) +type requestBinder[B, P, Q any] func(*Context) (Request[B, P, Q], error) // A response binder takes a Context with its Context.Writer and previous Context.RawResponse to write a typed Response output. type responseBinder[R any] func(*Context, Response[R]) (RawResponse, error) @@ -33,13 +44,13 @@ func requestBinderFactory[B, P, Q any](oa openapi, types requestTypes) requestBi var request = Request[B, P, Q]{ Headers: ctx.Request.Header, } - if err := requestBodyBinder(ctx.Request, &request.Body); err != nil { + if err := requestBodyBinder(ctx, &request.Body); err != nil { return request, newBadRequestErr(ctx, err, InBody) } - if err := pathParamsBinder(ctx.Params, &request.PathParams); err != nil { + if err := pathParamsBinder(ctx, &request.PathParams); err != nil { return request, newBadRequestErr(ctx, err, InPathParams) } - if err := queryParamsBinder(ctx.Request, &request.QueryParams); err != nil { + if err := queryParamsBinder(ctx, &request.QueryParams); err != nil { return request, newBadRequestErr(ctx, err, InQueryParams) } return request, nil @@ -47,17 +58,22 @@ func requestBinderFactory[B, P, Q any](oa openapi, types requestTypes) requestBi } // produce the httpRequest Body binder that can be used in runtime -func requestBodyBinderFactory[B any](requestBodyType reflect.Type, contentTypes ContentTypes) func(*http.Request, *B) error { - if requestBodyType == nilType { - return func(r *http.Request, body *B) error { return nil } +func requestBodyBinderFactory[B any](requestBodyType reflect.Type, contentTypes ContentTypes) binder[B] { + if requestBodyType == utils.NilType { + return nilBinder[B] } - return func(r *http.Request, body *B) error { - contentType, err := requestContentType(r, contentTypes, JSONContentType{}) + return func(ctx *Context, body *B) error { + input, err := validateBodyAndPopulateDefaults(ctx) + if err != nil { + return err + } + + contentType, err := requestContentType(input.Request, contentTypes, JSONContentType{}) if err != nil { return err } - defer func() { _ = r.Body.Close() }() - bodyBytes, err := io.ReadAll(r.Body) + defer func() { _ = input.Request.Body.Close() }() + bodyBytes, err := io.ReadAll(input.Request.Body) if err != nil { return err } @@ -68,17 +84,34 @@ func requestBodyBinderFactory[B any](requestBodyType reflect.Type, contentTypes } } +// validateBodyAndPopulateDefaults validate the request body with the openapi spec and populate the default values. +func validateBodyAndPopulateDefaults(ctx *Context) (*openapi3filter.RequestValidationInput, error) { + input := requestValidationInput(ctx) + if ctx.Operation.RequestBody != nil { + if err := openapi3filter.ValidateRequestBody(ctx.Request.Context(), input, ctx.Operation.RequestBody.Value); err != nil { + return nil, err + } + } + return input, nil +} + // produce the pathParamInValue pathParams binder that can be used in runtime -func pathBinderFactory[P any](pathParamsType reflect.Type) func(*httprouter.Params, *P) error { - if pathParamsType == nilType { - return func(params *httprouter.Params, body *P) error { return nil } +func pathBinderFactory[P any](pathParamsType reflect.Type) binder[P] { + if pathParamsType == utils.NilType { + return nilBinder[P] } - return func(params *httprouter.Params, pathParams *P) error { + return func(ctx *Context, target *P) error { + defaults, err := validateParamsAndPopulateDefaults(ctx, "path") + if err != nil { + return err + } + m := make(map[string][]string) - for _, v := range *params { - m[v.Key] = []string{v.Value} + for k, v := range defaults.PathParams { + m[k] = []string{v} } - if err := binding.Uri.BindUri(m, pathParams); err != nil { + + if err = binding.Uri.BindUri(m, target); err != nil { return err } return nil @@ -86,16 +119,17 @@ func pathBinderFactory[P any](pathParamsType reflect.Type) func(*httprouter.Para } // produce the queryParamInValue pathParams binder that can be used in runtime -func queryBinderFactory[Q any](queryParamsType reflect.Type) func(*http.Request, *Q) error { - if queryParamsType == nilType { - return func(*http.Request, *Q) error { return nil } +func queryBinderFactory[Q any](queryParamsType reflect.Type) binder[Q] { + if queryParamsType == utils.NilType { + return nilBinder[Q] } - paramFields := structKeys(queryParamsType, "form") + paramFields := utils.StructKeys(queryParamsType, "form") nonArrayParams := utils.NewSet[string]() for param, paramType := range paramFields { - if paramType.Type.Kind() == reflect.Slice || - paramType.Type.Kind() == reflect.Array || - (paramType.Type.Kind() == reflect.Pointer && + kind := paramType.Type.Kind() + if kind == reflect.Slice || + kind == reflect.Array || + (kind == reflect.Pointer && (paramType.Type.Elem().Kind() == reflect.Slice || paramType.Type.Elem().Kind() == reflect.Array)) { continue @@ -103,11 +137,17 @@ func queryBinderFactory[Q any](queryParamsType reflect.Type) func(*http.Request, nonArrayParams.Add(param) } - return func(r *http.Request, queryParams *Q) error { - if err := binding.Query.Bind(r, queryParams); err != nil { + return func(ctx *Context, queryParams *Q) error { + defaults, err := validateParamsAndPopulateDefaults(ctx, "query") + if err != nil { return err } - for param, values := range r.URL.Query() { + + if err = binding.Query.Bind(defaults.Request, queryParams); err != nil { + return err + } + + for param, values := range defaults.QueryParams { if nonArrayParams.Has(param) && len(values) > 1 { return fmt.Errorf("multiple values received for query param %s", param) } @@ -143,12 +183,33 @@ func responseBinderFactory[R any](responses handlerResponses, contentTypes Conte } bindResponseHeaders(ctx.Writer, r) ctx.Writer.WriteHeader(r.status) - _, err = ctx.Writer.Write(responseBytes) ctx.RawResponse.Status = r.status ctx.RawResponse.ContentType = r.contentType ctx.RawResponse.Body = responseBytes ctx.RawResponse.Headers = r.headers - return *ctx.RawResponse, err + + if _, err = ctx.Writer.Write(responseBytes); err != nil { + return *ctx.RawResponse, err + } + + validateResponse(ctx, r, responseBytes) + + return *ctx.RawResponse, nil + } +} + +// validateResponse validates the response against the spec. It logs a warning if the response violates the spec. +func validateResponse[R any](ctx *Context, r Response[R], responseBytes []byte) { + input := &openapi3filter.ResponseValidationInput{ + RequestValidationInput: requestValidationInput(ctx), + Status: r.status, + Header: r.headers, + Body: io.NopCloser(bytes.NewReader(responseBytes)), + Options: openapi3filter.DefaultOptions, + } + + if err := openapi3filter.ValidateResponse(ctx.Request.Context(), input); err != nil { + log.Printf("[WARNING] %s. response violates the spec\n", err) } } @@ -195,3 +256,50 @@ func responseContentType(responseContentType string, supportedTypes ContentTypes } return defaultContentType, fmt.Errorf("%w: %s", UnsupportedResponseContentTypeErr, responseContentType) } + +func validateParamsAndPopulateDefaults(ctx *Context, in string) (*openapi3filter.RequestValidationInput, error) { + input := requestValidationInput(ctx) + parameters := utils.Filter(utils.Map(ctx.Operation.Parameters, func(p *openapi3.ParameterRef) *openapi3.Parameter { + return p.Value + }), func(p *openapi3.Parameter) bool { return p.In == in }) + + for _, param := range parameters { + if err := openapi3filter.ValidateParameter(ctx.Request.Context(), input, param); err != nil { + return nil, err + } + } + + // after processing params input is populated with defaults + return input, nil +} + +func requestValidationInput(ctx *Context) *openapi3filter.RequestValidationInput { + input := openapi3filter.RequestValidationInput{ + Request: &http.Request{}, + PathParams: make(map[string]string), + QueryParams: url.Values{}, + Options: openapi3filter.DefaultOptions, + Route: &routers.Route{ + Operation: ctx.Operation.Operation, + }, + ParamDecoder: nil, + } + + if ctx.Request != nil { + input.Request = ctx.Request + if ctx.Request.URL != nil { + input.QueryParams = ctx.Request.URL.Query() + } + } + + if ctx.Params != nil { + input.PathParams = utils.FromEntries(utils.Map(*ctx.Params, func(p httprouter.Param) utils.Entry[string, string] { + return utils.Entry[string, string]{ + Key: p.Key, + Value: p.Value, + } + })) + } + + return &input +} diff --git a/router/binders_test.go b/router/binders_test.go index 8ca3027..f78866f 100644 --- a/router/binders_test.go +++ b/router/binders_test.go @@ -1,17 +1,19 @@ package router import ( - "bytes" "errors" "io" "net/http" - "net/url" + "net/http/httptest" "reflect" "testing" + "github.com/getkin/kin-openapi/openapi3" "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/piiano/cellotape/router/utils" ) func TestResponseContentType(t *testing.T) { @@ -33,11 +35,7 @@ type StructType struct { func TestQueryBinderFactory(t *testing.T) { queryBinder := queryBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - requestURL, err := url.Parse("http:0.0.0.0:90/abc?Foo=42") - require.NoError(t, err) - err = queryBinder(&http.Request{ - URL: requestURL, - }, ¶ms) + err := queryBinder(testContext(withURL(t, "http:0.0.0.0:90/abc?Foo=42")), ¶ms) require.NoError(t, err) assert.Equal(t, StructType{Foo: 42}, params) } @@ -49,11 +47,7 @@ type StructWithArrayType struct { func TestQueryBinderFactoryWithArrayType(t *testing.T) { queryBinder := queryBinderFactory[StructWithArrayType](reflect.TypeOf(StructWithArrayType{})) var params StructWithArrayType - requestURL, err := url.Parse("http:0.0.0.0:90/abc?Foo=42&Foo=6&Foo=7") - require.NoError(t, err) - err = queryBinder(&http.Request{ - URL: requestURL, - }, ¶ms) + err := queryBinder(testContext(withURL(t, "http:0.0.0.0:90/abc?Foo=42&Foo=6&Foo=7")), ¶ms) require.NoError(t, err) assert.Equal(t, StructWithArrayType{Foo: []int{42, 6, 7}}, params) } @@ -61,33 +55,24 @@ func TestQueryBinderFactoryWithArrayType(t *testing.T) { func TestQueryBinderFactoryMultipleParamToNonArrayError(t *testing.T) { queryBinder := queryBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - requestURL, err := url.Parse("http:0.0.0.0:90/abc?Foo=42&Foo=6&Foo=7") - require.NoError(t, err) - err = queryBinder(&http.Request{ - URL: requestURL, - }, ¶ms) + err := queryBinder(testContext(withURL(t, "http:0.0.0.0:90/abc?Foo=42&Foo=6&Foo=7")), ¶ms) require.Error(t, err) } func TestQueryBinderFactoryError(t *testing.T) { queryBinder := queryBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - requestURL, err := url.Parse("http:0.0.0.0:90/abc?Foo=abc") - require.NoError(t, err) - err = queryBinder(&http.Request{ - URL: requestURL, - }, ¶ms) - + err := queryBinder(testContext(withURL(t, "http:0.0.0.0:90/abc?Foo=abc")), ¶ms) require.Error(t, err) } func TestPathBinderFactory(t *testing.T) { pathBinder := pathBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - err := pathBinder(&httprouter.Params{{ + err := pathBinder(testContext(withParams(&httprouter.Params{{ Key: "Foo", Value: "42", - }}, ¶ms) + }})), ¶ms) require.NoError(t, err) assert.Equal(t, StructType{Foo: 42}, params) } @@ -95,19 +80,32 @@ func TestPathBinderFactory(t *testing.T) { func TestPathBinderFactoryError(t *testing.T) { pathBinder := pathBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - err := pathBinder(&httprouter.Params{{ + err := pathBinder(testContext(withParams(&httprouter.Params{{ Key: "Foo", Value: "bar", - }}, ¶ms) + }})), ¶ms) require.Error(t, err) } func TestRequestBodyBinderFactory(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&http.Request{ - Body: io.NopCloser(bytes.NewBuffer([]byte("42"))), - }, ¶m) + err := requestBodyBinder(testContext(withBody("42")), ¶m) + require.NoError(t, err) + assert.Equal(t, 42, param) +} + +func TestRequestBodyBinderFactoryWithSchema(t *testing.T) { + testOp := openapi3.NewOperation() + testOp.RequestBody = &openapi3.RequestBodyRef{ + Value: openapi3.NewRequestBody().WithJSONSchema(openapi3.NewIntegerSchema()), + } + requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) + var param int + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "application/json"), + withOperation(testOp)), ¶m) require.NoError(t, err) assert.Equal(t, 42, param) } @@ -115,9 +113,8 @@ func TestRequestBodyBinderFactory(t *testing.T) { func TestRequestBodyBinderFactoryError(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&http.Request{ - Body: io.NopCloser(bytes.NewBuffer([]byte(`"foo"`))), - }, ¶m) + + err := requestBodyBinder(testContext(withBody(`"foo"`)), ¶m) require.Error(t, err) } @@ -132,29 +129,27 @@ func (r readerWithError) Read(_ []byte) (int, error) { func TestRequestBodyBinderFactoryReaderError(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&http.Request{ - Body: io.NopCloser(readerWithError(`42`)), - }, ¶m) + err := requestBodyBinder(testContext( + withBodyReader(io.NopCloser(readerWithError(`42`)))), ¶m) require.Error(t, err) } func TestRequestBodyBinderFactoryContentTypeError(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&http.Request{ - Header: http.Header{"Content-Type": {"no-such-content-type"}}, - Body: io.NopCloser(bytes.NewBuffer([]byte(`42`))), - }, ¶m) + + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "no-such-content-type")), ¶m) require.Error(t, err) } func TestRequestBodyBinderFactoryContentTypeWithCharset(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&http.Request{ - Header: http.Header{"Content-Type": {"application/json; charset=utf-8"}}, - Body: io.NopCloser(bytes.NewBuffer([]byte("42"))), - }, ¶m) + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "application/json; charset=utf-8")), ¶m) require.NoError(t, err) assert.Equal(t, 42, param) } @@ -162,20 +157,18 @@ func TestRequestBodyBinderFactoryContentTypeWithCharset(t *testing.T) { func TestRequestBodyBinderFactoryInvalidContentType(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&http.Request{ - Header: http.Header{"Content-Type": {"invalid content type"}}, - Body: io.NopCloser(bytes.NewBuffer([]byte("42"))), - }, ¶m) + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "invalid content type")), ¶m) require.Error(t, err) } func TestRequestBodyBinderFactoryContentTypeAnyWithCharset(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&http.Request{ - Header: http.Header{"Content-Type": {"*/*; charset=utf-8"}}, - Body: io.NopCloser(bytes.NewBuffer([]byte("42"))), - }, ¶m) + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "*/*; charset=utf-8")), ¶m) require.NoError(t, err) assert.Equal(t, 42, param) } @@ -193,12 +186,11 @@ type CollidingFieldsParams struct { func TestBindingEmbeddedQueryParamsCollidingFields(t *testing.T) { requestBodyBinder := queryBinderFactory[CollidingFieldsParams](reflect.TypeOf(CollidingFieldsParams{})) - requestURL, err := url.Parse("http://http:0.0.0.0:8080/path?param1=foo¶m2=bar") - require.NoError(t, err) var param CollidingFieldsParams - err = requestBodyBinder(&http.Request{ - URL: requestURL, - }, ¶m) + + ctx := testContext(withURL(t, "http://http:0.0.0.0:8080/path?param1=foo¶m2=bar")) + + err := requestBodyBinder(ctx, ¶m) require.NoError(t, err) require.Equal(t, "foo", param.CollidingFieldsParam1.Value) require.Equal(t, "bar", param.CollidingFieldsParam2.Value) @@ -217,13 +209,56 @@ type CollidingParams struct { func TestBindingEmbeddedQueryParamsCollidingParams(t *testing.T) { requestBodyBinder := queryBinderFactory[CollidingParams](reflect.TypeOf(CollidingParams{})) - requestURL, err := url.Parse("http://http:0.0.0.0:8080/path?param1=42") - require.NoError(t, err) + var param CollidingParams - err = requestBodyBinder(&http.Request{ - URL: requestURL, - }, ¶m) + err := requestBodyBinder(testContext( + withURL(t, "http://http:0.0.0.0:8080/path?param1=42")), ¶m) require.NoError(t, err) require.Equal(t, "42", param.CollidingParamString.Value) require.Equal(t, 42, param.CollidingParamInt.Value) } + +type errWriter struct{} + +func (e errWriter) Header() http.Header { return http.Header{} } +func (e errWriter) WriteHeader(int) {} +func (e errWriter) Write(i []byte) (int, error) { + return 0, errors.New("error") +} + +func TestErrOnWriterError(t *testing.T) { + type R = OKResponse[string] + responses := extractResponses(utils.GetType[R]()) + binder := responseBinderFactory[R](responses, DefaultContentTypes()) + response := SendOK(R{OK: "foo"}).ContentType("unknown") + + testCases := []struct { + name string + writer http.ResponseWriter + assertion func(require.TestingT, error, ...any) + }{ + { + name: "writer error", + writer: errWriter{}, + assertion: require.Error, + }, + { + name: "proper writer", + writer: httptest.NewRecorder(), + assertion: require.NoError, + }, + } + testOp := openapi3.NewOperation() + testOp.AddResponse(200, openapi3.NewResponse().WithJSONSchema(openapi3.NewStringSchema())) + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + ctx := testContext( + withOperation(testOp), + withResponseWriter(test.writer), + ) + _, err := binder(ctx, response) + test.assertion(t, err) + }) + } +} diff --git a/router/content_types.go b/router/content_types.go index 0a0214f..36ae3dc 100644 --- a/router/content_types.go +++ b/router/content_types.go @@ -104,7 +104,13 @@ func (t JSONContentType) Encode(value any) ([]byte, error) { return json.Mars func (t JSONContentType) Decode(data []byte, value any) error { return json.Unmarshal(data, value) } func (t JSONContentType) ValidateTypeSchema( logger utils.Logger, level utils.LogLevel, goType reflect.Type, schema openapi3.Schema) error { - return schema_validator.NewTypeSchemaValidator(logger, level, goType, schema).Validate() + validator := schema_validator.NewTypeSchemaValidator(goType, schema) + err := validator.Validate() + for _, errMessage := range validator.Errors() { + logger.Log(level, errMessage) + } + + return err } func DefaultContentTypes() ContentTypes { diff --git a/router/content_types_test.go b/router/content_types_test.go index 65632f8..76b9f9e 100644 --- a/router/content_types_test.go +++ b/router/content_types_test.go @@ -1,8 +1,15 @@ package router import ( + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" "reflect" "testing" + "testing/iotest" "github.com/getkin/kin-openapi/openapi3" "github.com/stretchr/testify/assert" @@ -67,6 +74,106 @@ func TestContentTypeMime(t *testing.T) { } } +type foo struct { + Foo string `json:"foo"` +} +type fooContentType struct { + shouldErr bool +} + +func (f fooContentType) Mime() string { return "foo" } + +func (f fooContentType) Encode(a any) ([]byte, error) { + return []byte(a.(foo).Foo), nil +} + +func (f fooContentType) Decode(bytes []byte, a any) error { + if f.shouldErr { + return errors.New("foo decode error") + } + switch typedValue := a.(type) { + case *foo: + (*typedValue).Foo = string(bytes) + case *any: + *typedValue = string(bytes) + } + return nil +} + +func (f fooContentType) ValidateTypeSchema(_ utils.Logger, _ utils.LogLevel, _ reflect.Type, _ openapi3.Schema) error { + return nil +} + +func TestValidationsWithCustomContentType(t *testing.T) { + testSpec, err := NewSpecFromData([]byte(` +paths: + /test: + post: + operationId: test + requestBody: + content: + foo: + schema: + type: string + responses: + '200': + description: ok +`)) + require.NoError(t, err) + + testCases := []struct { + contentType ContentType + bodyReader io.ReadCloser + shouldErr bool + }{ + { + contentType: fooContentType{}, + bodyReader: io.NopCloser(bytes.NewBufferString("bar")), + }, + { + contentType: fooContentType{shouldErr: true}, + bodyReader: io.NopCloser(bytes.NewBufferString("bar")), + shouldErr: true, + }, + { + contentType: fooContentType{}, + bodyReader: io.NopCloser(iotest.ErrReader(errors.New("failed reading body"))), + shouldErr: true, + }, + } + + for _, test := range testCases { + var calledWithBody *foo + var badRequestErr error + router := NewOpenAPIRouter(testSpec). + WithContentType(test.contentType). + WithOperation("test", HandlerFunc[foo, Nil, Nil, OKResponse[Nil]](func(_ *Context, r Request[foo, Nil, Nil]) (Response[OKResponse[Nil]], error) { + calledWithBody = &r.Body + return SendOK(OKResponse[Nil]{}), nil + }), ErrorHandler(func(_ *Context, err error) (Response[any], error) { + badRequestErr = err + return Response[any]{}, nil + })) + handler, err := router.AsHandler() + require.NoError(t, err) + + handler.ServeHTTP(&httptest.ResponseRecorder{}, &http.Request{ + Method: http.MethodPost, + URL: &url.URL{Path: "/test"}, + Header: http.Header{"Content-Type": []string{"foo"}}, + Body: test.bodyReader, + }) + + if test.shouldErr { + //require.Nil(t, calledWithBody) + require.Error(t, badRequestErr) + } else { + assert.Equal(t, foo{Foo: "bar"}, *calledWithBody) + require.NoError(t, badRequestErr) + } + } +} + func TestOctetStreamContentTypeBytesSlice(t *testing.T) { encodedBytes, err := OctetStreamContentType{}.Encode([]byte("foo")) require.NoError(t, err) diff --git a/router/core.go b/router/core.go index 007fe54..d0d1d5c 100644 --- a/router/core.go +++ b/router/core.go @@ -1,27 +1,33 @@ package router import ( + "encoding/json" + "errors" + "io" "log" "net/http" "regexp" "runtime/debug" + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" "github.com/julienschmidt/httprouter" "github.com/piiano/cellotape/router/utils" ) func createMainRouterHandler(oa *openapi) (http.Handler, error) { + // Customize the error message returned by the kin-openapi library to be more user-friendly. + openapi3filter.DefaultOptions.WithCustomSchemaErrorFunc(func(err *openapi3.SchemaError) string { + return err.Reason + }) flatOperations := flattenOperations(oa.group) if err := validateOpenAPIRouter(oa, flatOperations); err != nil { return nil, err } router := httprouter.New() router.HandleMethodNotAllowed = false - ////router.PanicHandler = nil - //router.PanicHandler = func(writer http.ResponseWriter, request *http.Request, i interface{}) { - // log.Println("http-router handler") - //} + logger := oa.logger() pathParamsMatcher := regexp.MustCompile(`\{([^/}]*)}`) @@ -34,9 +40,55 @@ func createMainRouterHandler(oa *openapi) (http.Handler, error) { router.Handle(specOp.Method, path, httpRouterHandler) logger.Infof("register handler for operation %q - %s %s", flatOp.id, specOp.Method, specOp.Path) } + + // For Kin-openapi to be able to validate a request and set default values it need to know how to decode and encode + // the request body for any supported content type. + for _, contentType := range oa.contentTypes { + mimeType := contentType.Mime() + if openapi3filter.RegisteredBodyEncoder(mimeType) == nil { + openapi3filter.RegisterBodyEncoder(contentType.Mime(), contentType.Encode) + } + if openapi3filter.RegisteredBodyDecoder(mimeType) == nil { + openapi3filter.RegisterBodyDecoder(contentType.Mime(), createDecoder(contentType)) + } + } + return router, nil } +func createDecoder(contentType ContentType) func(reader io.Reader, _ http.Header, schema *openapi3.SchemaRef, enc openapi3filter.EncodingFn) (any, error) { + return func(reader io.Reader, _ http.Header, schema *openapi3.SchemaRef, enc openapi3filter.EncodingFn) (any, error) { + bytes, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + + var target any + if err = contentType.Decode(bytes, &target); err != nil { + return nil, err + } + + // For kin-openapi to be able to validate a request it requires that the decoded value will on of + // the values received when decoding JSON to any. + // e.g. any, []any, []map[string]any, etc. + // + // After using the custom decoder we get a value of the type of the target struct. + // To overcome this we marshal the target to JSON and then unmarshal it to any. + + jsonBytes, err := json.Marshal(target) + if err != nil { + return nil, err + } + + var jsonValue any + if err = json.Unmarshal(jsonBytes, &jsonValue); err != nil { + return nil, err + } + + return jsonValue, nil + } +} + func (oa *openapi) logger() utils.Logger { return utils.NewLoggerWithLevel(oa.options.LogOutput, oa.options.LogLevel) } @@ -59,6 +111,16 @@ func chainHandlers(oa openapi, handlers ...handler) (head BoundHandlerFunc) { for i := len(handlers) - 1; i >= 0; i-- { next = handlers[i].handlerFunc.handlerFactory(oa, next) } + next = ErrorHandler(func(c *Context, err error) (Response[any], error) { + var badRequestError BadRequestErr + if err != nil && c.RawResponse.Status == 0 && errors.As(err, &badRequestError) { + c.Writer.Header().Add("Content-Type", "text/plain") + c.Writer.WriteHeader(400) + _, writeErr := c.Writer.Write([]byte(err.Error())) + return Error[any](writeErr) + } + return Error[any](err) + }).handlerFactory(oa, next) return next } @@ -74,10 +136,9 @@ func asHttpRouterHandler(oa openapi, specOp SpecOperation, head BoundHandlerFunc Params: ¶ms, RawResponse: &RawResponse{Status: 0}, } + _, err := head(ctx) if err != nil || ctx.RawResponse.Status == 0 { - log.Println("unhandled error") - log.Println(err) writer.WriteHeader(500) return } diff --git a/router/core_test.go b/router/core_test.go index 5a8455c..87564bd 100644 --- a/router/core_test.go +++ b/router/core_test.go @@ -7,6 +7,9 @@ import ( "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/piiano/cellotape/router/utils" ) func TestDefaultRecoverFromError(t *testing.T) { @@ -27,3 +30,26 @@ func TestName(t *testing.T) { assert.Equal(t, 500, writer.Code) } + +func TestFailStartOnValidationError(t *testing.T) { + _, err := createMainRouterHandler(&openapi{ + spec: NewSpec(), + options: DefaultOptions(), + group: group{ + operations: []operation{ + { + id: "test", + handler: handler{ + request: requestTypes{ + requestBody: utils.NilType, + pathParams: utils.NilType, + queryParams: utils.NilType, + }, + }, + }, + }, + }, + }) + + require.Error(t, err) +} diff --git a/router/error_messages.go b/router/error_messages.go index 57aae96..a8f2e3f 100644 --- a/router/error_messages.go +++ b/router/error_messages.go @@ -57,5 +57,5 @@ func paramMissingImplementationInChain(in string, name string, operationId strin return fmt.Sprintf("%s param %q exists on the spec for operation %q but not declared on any handler", in, name, operationId) } func anExcludedOperationIsImplemented(operationId string) string { - return fmt.Sprintf("the excluded operation %s is implemented by a handler", operationId) + return fmt.Sprintf("the excluded operation %q is implemented by a handler", operationId) } diff --git a/router/error_messages_test.go b/router/error_messages_test.go index 8527097..72a8013 100644 --- a/router/error_messages_test.go +++ b/router/error_messages_test.go @@ -80,4 +80,7 @@ func TestErrorMessages(t *testing.T) { `query param "bar" exists on the spec for operation "foo" but not declared on any handler`, paramMissingImplementationInChain("query", "bar", "foo")) + assert.Equal(t, + `the excluded operation "foo" is implemented by a handler`, + anExcludedOperationIsImplemented("foo")) } diff --git a/router/handler.go b/router/handler.go index a1cb3c8..6b9403a 100644 --- a/router/handler.go +++ b/router/handler.go @@ -2,9 +2,10 @@ package router import ( "net/http" - "reflect" "github.com/julienschmidt/httprouter" + + "github.com/piiano/cellotape/router/utils" ) // Handler described the HandlerFunc in a non parametrized way. @@ -95,29 +96,18 @@ func NewHandler[B, P, Q, R any](h HandlerFunc[B, P, Q, R]) Handler { return h } -// getType returns reflect.Type of the generic parameter it receives. -func getType[T any]() reflect.Type { return reflect.TypeOf(new(T)).Elem() } - -// Nil represents an empty type. -// You can use it with the HandlerFunc generic parameters to declare no Request with no request body, no path or query -// params, or responses with no response body. -type Nil *uintptr - -// nilType represent the type of Nil. -var nilType = getType[Nil]() - // requestTypes extracts the request types defined by the HandlerFunc func (h HandlerFunc[B, P, Q, R]) requestTypes() requestTypes { return requestTypes{ - requestBody: getType[B](), - pathParams: getType[P](), - queryParams: getType[Q](), + requestBody: utils.GetType[B](), + pathParams: utils.GetType[P](), + queryParams: utils.GetType[Q](), } } // responseTypes extracts the responses defined by the HandlerFunc and returns handlerResponses func (h HandlerFunc[B, P, Q, R]) responseTypes() handlerResponses { - return extractResponses(getType[R]()) + return extractResponses(utils.GetType[R]()) } // sourcePosition finds the sourcePosition of the HandlerFunc function for printing meaningful messages during validations diff --git a/router/helpers.go b/router/helpers.go index 90cc019..974652d 100644 --- a/router/helpers.go +++ b/router/helpers.go @@ -2,8 +2,13 @@ package router import ( "net/http" + + "github.com/piiano/cellotape/router/utils" ) +type Nil = utils.Nil +type MultiType[T any] utils.MultiType[T] + // Send constructs a new Response. func Send[R any](response R, headers ...http.Header) Response[R] { aggregatedHeaders := make(http.Header, 0) @@ -62,7 +67,7 @@ func Error[R any](err error) (Response[R], error) { // RawHandler adds a handler that doesn't define any type information. func RawHandler(f func(c *Context) error) Handler { - return NewHandler(func(c *Context, _ Request[Nil, Nil, Nil]) (Response[any], error) { + return NewHandler(func(c *Context, _ Request[utils.Nil, utils.Nil, utils.Nil]) (Response[any], error) { return Response[any]{}, f(c) }) } diff --git a/router/helpers_test.go b/router/helpers_test.go index b36c1d7..e69d7fe 100644 --- a/router/helpers_test.go +++ b/router/helpers_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/piiano/cellotape/router/utils" ) type OKResponse[R any] struct { @@ -112,9 +114,9 @@ func TestRawHandler(t *testing.T) { assert.Equal(t, []byte("test"), response.Body) return nil }) - assert.Equal(t, nilType, rawHandler.requestTypes().requestBody) - assert.Equal(t, nilType, rawHandler.requestTypes().pathParams) - assert.Equal(t, nilType, rawHandler.requestTypes().queryParams) + assert.Equal(t, utils.NilType, rawHandler.requestTypes().requestBody) + assert.Equal(t, utils.NilType, rawHandler.requestTypes().pathParams) + assert.Equal(t, utils.NilType, rawHandler.requestTypes().queryParams) assert.Len(t, rawHandler.responseTypes(), 0) rawResponse := RawResponse{ Status: 200, @@ -125,7 +127,7 @@ func TestRawHandler(t *testing.T) { handlerFunc := rawHandler.handlerFactory(openapi{}, func(c *Context) (RawResponse, error) { return rawResponse, nil }) - resp, err := handlerFunc(&Context{Request: &http.Request{}, RawResponse: &RawResponse{}}) + resp, err := handlerFunc(testContext()) require.ErrorIs(t, err, UnsupportedResponseStatusErr) assert.Zero(t, resp) diff --git a/router/model.go b/router/model.go index ee23801..9ecc17b 100644 --- a/router/model.go +++ b/router/model.go @@ -51,11 +51,11 @@ type handler struct { // requestTypes described the parameter types provided in the Request input of a handler function type requestTypes struct { - // requestBody is the type of the Body parameter. type is nilType if there is no httpRequest Body + // requestBody is the type of the Body parameter. type is NilType if there is no httpRequest Body requestBody reflect.Type - // pathParams is the type of the Body parameter. type is nilType if there is no pathParamInValue pathParams + // pathParams is the type of the Body parameter. type is NilType if there is no pathParamInValue pathParams pathParams reflect.Type - // queryParams is the type of the Body parameter. type is nilType if there is no queryParamInValue pathParams + // queryParams is the type of the Body parameter. type is NilType if there is no queryParamInValue pathParams queryParams reflect.Type } diff --git a/router/operation_handler_test.go b/router/operation_handler_test.go index d220164..1291cb4 100644 --- a/router/operation_handler_test.go +++ b/router/operation_handler_test.go @@ -10,21 +10,25 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/piiano/cellotape/router/utils" ) func TestHandlerFuncTypeExtraction(t *testing.T) { - fn := HandlerFunc[Nil, Nil, Nil, Nil](func(*Context, Request[Nil, Nil, Nil]) (Response[Nil], error) { return Response[Nil]{}, nil }) + fn := HandlerFunc[utils.Nil, utils.Nil, utils.Nil, utils.Nil](func(*Context, Request[utils.Nil, utils.Nil, utils.Nil]) (Response[utils.Nil], error) { + return Response[utils.Nil]{}, nil + }) types := fn.requestTypes() - assert.Equal(t, types.requestBody, nilType) - assert.Equal(t, types.pathParams, nilType) - assert.Equal(t, types.queryParams, nilType) + assert.Equal(t, types.requestBody, utils.NilType) + assert.Equal(t, types.pathParams, utils.NilType) + assert.Equal(t, types.queryParams, utils.NilType) } func TestRouterAsHandler(t *testing.T) { type responses struct { Answer int `status:"200"` } - fn := HandlerFunc[Nil, Nil, Nil, responses](func(*Context, Request[Nil, Nil, Nil]) (Response[responses], error) { + fn := HandlerFunc[utils.Nil, utils.Nil, utils.Nil, responses](func(*Context, Request[utils.Nil, utils.Nil, utils.Nil]) (Response[responses], error) { return SendOKJSON(responses{Answer: 42}), nil }) spec, err := NewSpecFromData([]byte(` diff --git a/router/options.go b/router/options.go index cd8a696..2781a78 100644 --- a/router/options.go +++ b/router/options.go @@ -8,6 +8,15 @@ import ( "github.com/piiano/cellotape/router/utils" ) +type LogLevel = utils.LogLevel + +const ( + LogLevelError = utils.Error + LogLevelWarn = utils.Warn + LogLevelInfo = utils.Info + LogLevelOff = utils.Off +) + // Behaviour defines a possible behaviour for a validation error. // Possible values are PropagateError, PrintWarning and Ignore. // @@ -75,7 +84,7 @@ type Options struct { // By default, LogLevel is set to utils.Info to print all info to the log. // The router prints to the log only during initialization to show validation errors, warnings and useful info. // No printing is done after initialization. - LogLevel utils.LogLevel `json:"logLevel,omitempty"` + LogLevel LogLevel `json:"logLevel,omitempty"` // LogOutput defines where to write the outputs too. // By default, it is set to write to os.Stderr. @@ -159,7 +168,7 @@ type SchemaValidationOptions struct { func DefaultOptions() Options { return Options{ RecoverOnPanic: true, - LogLevel: utils.Info, + LogLevel: LogLevelInfo, LogOutput: os.Stderr, DefaultOperationValidation: OperationValidationOptions{ ValidateRequestBody: PropagateError, diff --git a/router/options_test.go b/router/options_test.go index 81c989c..f6ce38a 100644 --- a/router/options_test.go +++ b/router/options_test.go @@ -2,7 +2,7 @@ package router import ( "encoding/json" - "log" + "os" "testing" "github.com/invopop/jsonschema" @@ -13,11 +13,13 @@ import ( ) func TestSchema(t *testing.T) { - schema := jsonschema.Reflect(&Options{}) - bytes, _ := schema.MarshalJSON() - log.Println(string(bytes)) + bytes, _ := json.MarshalIndent(schema, "", " ") + + schemaFile, err := os.ReadFile("../options-schema.json") + assert.NoError(t, err) + assert.Equal(t, string(schemaFile), string(bytes)) } func TestBehaviourZeroValue(t *testing.T) { diff --git a/router/response.go b/router/response.go index 8e0b66f..44a37a5 100644 --- a/router/response.go +++ b/router/response.go @@ -5,6 +5,8 @@ import ( "math/bits" "reflect" "strconv" + + "github.com/piiano/cellotape/router/utils" ) const statusTag = "status" @@ -44,7 +46,7 @@ func extractResponses(t reflect.Type) handlerResponses { status: status, fieldIndex: field.Index, responseType: field.Type, - isNilType: field.Type == nilType, + isNilType: field.Type == utils.NilType, } } return responseTypesMap diff --git a/router/response_test.go b/router/response_test.go index 0dfe97b..f212627 100644 --- a/router/response_test.go +++ b/router/response_test.go @@ -7,6 +7,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/piiano/cellotape/router/utils" ) type EmbeddedResponse struct { @@ -15,7 +17,7 @@ type EmbeddedResponse struct { func TestExtractResponses(t *testing.T) { tableTest(t, extractResponses, []testCase[reflect.Type, handlerResponses]{ - {in: nilType, out: handlerResponses{}}, + {in: utils.NilType, out: handlerResponses{}}, {in: reflect.TypeOf(struct{}{}), out: handlerResponses{}}, {in: reflect.TypeOf(struct{ ok string }{}), out: handlerResponses{}}, {in: reflect.TypeOf(struct { diff --git a/router/runtime_errors.go b/router/runtime_errors.go index 2d75dc5..dad9ef4 100644 --- a/router/runtime_errors.go +++ b/router/runtime_errors.go @@ -3,6 +3,8 @@ package router import ( "errors" "fmt" + + "github.com/piiano/cellotape/router/utils" ) // Runtime errors causes @@ -71,8 +73,8 @@ func (e BadRequestErr) Unwrap() error { // ErrorHandler allows providing a handler function that can handle errors occurred in the handlers chain. // This type of handler is particularly useful for handling BadRequestErr caused by a request binding errors and // translate it to an HTTP response. -func ErrorHandler[R any](errHandler func(c *Context, err error) (Response[R], error)) HandlerFunc[Nil, Nil, Nil, R] { - return func(c *Context, _ Request[Nil, Nil, Nil]) (Response[R], error) { +func ErrorHandler[R any](errHandler func(c *Context, err error) (Response[R], error)) HandlerFunc[utils.Nil, utils.Nil, utils.Nil, R] { + return func(c *Context, _ Request[utils.Nil, utils.Nil, utils.Nil]) (Response[R], error) { _, err := c.Next() if err != nil { return errHandler(c, err) diff --git a/router/runtime_errors_test.go b/router/runtime_errors_test.go index 39690de..3cecaca 100644 --- a/router/runtime_errors_test.go +++ b/router/runtime_errors_test.go @@ -1,26 +1,88 @@ package router import ( + "bytes" "errors" + "io" "net/http" "net/http/httptest" + "net/url" "reflect" "testing" + "github.com/getkin/kin-openapi/openapi3" "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/piiano/cellotape/router/utils" ) -var testContext = func() *Context { - return &Context{ +type contextModifier func(*Context) + +func testContext(modifiers ...contextModifier) *Context { + ctx := &Context{ + Operation: SpecOperation{ + Operation: openapi3.NewOperation(), + }, RawResponse: &RawResponse{}, Request: &http.Request{ + URL: &url.URL{}, Header: http.Header{}, }, Writer: &httptest.ResponseRecorder{}, Params: &httprouter.Params{}, } + for _, modifier := range modifiers { + modifier(ctx) + } + return ctx +} + +func withURL(t *testing.T, urlString string) contextModifier { + urlValue, err := url.Parse(urlString) + require.NoError(t, err) + return func(ctx *Context) { + ctx.Request.URL = urlValue + } +} + +func withBody(body string) contextModifier { + bodyReader := io.NopCloser(bytes.NewBuffer([]byte(body))) + + return withBodyReader(bodyReader) +} + +func withBodyReader(bodyReader io.ReadCloser) contextModifier { + return func(ctx *Context) { + ctx.Request.Body = bodyReader + } +} + +func withParams(params *httprouter.Params) contextModifier { + return func(ctx *Context) { + ctx.Params = params + } +} + +func withHeader(header string, values ...string) contextModifier { + return func(ctx *Context) { + for _, value := range values { + ctx.Request.Header.Add(header, value) + } + } +} + +func withOperation(operation *openapi3.Operation) contextModifier { + return func(ctx *Context) { + ctx.Operation.Operation = operation + } +} + +func withResponseWriter(writer http.ResponseWriter) contextModifier { + return func(ctx *Context) { + ctx.Writer = writer + } } func TestNewBadRequestErr(t *testing.T) { @@ -61,9 +123,9 @@ func TestErrorHandler(t *testing.T) { errorHandler := ErrorHandler(func(c *Context, err error) (Response[ErrorResponse], error) { return SendText(ErrorResponse{Message: err.Error()}).Status(400), nil }) - assert.Equal(t, nilType, errorHandler.requestTypes().requestBody) - assert.Equal(t, nilType, errorHandler.requestTypes().pathParams) - assert.Equal(t, nilType, errorHandler.requestTypes().queryParams) + assert.Equal(t, utils.NilType, errorHandler.requestTypes().requestBody) + assert.Equal(t, utils.NilType, errorHandler.requestTypes().pathParams) + assert.Equal(t, utils.NilType, errorHandler.requestTypes().queryParams) assert.Equal(t, handlerResponses{ 200: { status: 200, diff --git a/router/schema_validator/array_schema_validator.go b/router/schema_validator/array_schema_validator.go index 9a86bc7..31aac9b 100644 --- a/router/schema_validator/array_schema_validator.go +++ b/router/schema_validator/array_schema_validator.go @@ -1,16 +1,23 @@ package schema_validator import ( - "fmt" - "reflect" + "github.com/getkin/kin-openapi/openapi3" ) -func (c typeSchemaValidatorContext) validateArraySchema() error { - if c.schema.Type != arraySchemaType { - return nil +func (c typeSchemaValidatorContext) validateArraySchema() { + isGoTypeArray := isArrayGoType(c.goType) + if c.schema.Type == openapi3.TypeArray && !isGoTypeArray { + c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) } - if c.goType.Kind() != reflect.Array && c.goType.Kind() != reflect.Slice { - return fmt.Errorf(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) + + if !isSchemaTypeArrayOrEmpty(c.schema) { + if isGoTypeArray && !isSliceOfBytes(c.goType) { + c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) + } + return + } + + if isGoTypeArray && c.schema.Items != nil { + _ = c.WithSchemaAndType(*c.schema.Items.Value, c.goType.Elem()).Validate() } - return c.WithSchemaAndType(*c.schema.Items.Value, c.goType.Elem()).Validate() } diff --git a/router/schema_validator/array_schema_validator_test.go b/router/schema_validator/array_schema_validator_test.go index 1fe4c31..2a22bda 100644 --- a/router/schema_validator/array_schema_validator_test.go +++ b/router/schema_validator/array_schema_validator_test.go @@ -5,17 +5,40 @@ import ( "testing" "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" + + "github.com/piiano/cellotape/router/utils" ) +// according to the spec the array validation properties should apply oly when the type is set to array +func TestArraySchemaValidator(t *testing.T) { + // create with NewSchema and not with NewArraySchema for an untyped schema + schemaWithItemsProperty := openapi3.NewArraySchema(). + WithItems(openapi3.NewObjectSchema(). + WithProperty("token_id", openapi3.NewStringSchema())) + goType := utils.GetType[[]struct { + Value string `json:"token_id"` + }]() + err := schemaValidator(*schemaWithItemsProperty).WithType(goType).Validate() + require.NoErrorf(t, err, "expect untyped schema to be compatible with %s type", goType) +} + // according to the spec the array validation properties should apply oly when the type is set to array func TestArraySchemaValidatorWithUntypedSchema(t *testing.T) { // create with NewSchema and not with NewArraySchema for an untyped schema untypedSchemaWithItemsProperty := openapi3.NewSchema().WithItems(openapi3.NewStringSchema()) - validator := schemaValidator(*untypedSchemaWithItemsProperty) - for _, validType := range types { - t.Run(validType.String(), func(t *testing.T) { - if err := validator.WithType(validType).validateArraySchema(); err != nil { - t.Errorf("expect untyped schema to be compatible with %s type", validType) + + for _, goType := range types { + valid := (kindIs(reflect.Array, reflect.Slice)(goType) && isSerializedFromString(goType.Elem())) || + !kindIs(reflect.Array, reflect.Slice)(goType) || + (uuidType.ConvertibleTo(goType)) + + t.Run(goType.String(), func(t *testing.T) { + err := schemaValidator(*untypedSchemaWithItemsProperty).WithType(goType).Validate() + if valid { + require.NoErrorf(t, err, "expect untyped schema to be compatible with %s type", goType) + } else { + require.Errorf(t, err, "expect untyped schema to be incompatible with %s type", goType) } }) } diff --git a/router/schema_validator/assertions.go b/router/schema_validator/assertions.go new file mode 100644 index 0000000..24cafa4 --- /dev/null +++ b/router/schema_validator/assertions.go @@ -0,0 +1,147 @@ +package schema_validator + +import ( + "reflect" + "time" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/google/uuid" + + "github.com/piiano/cellotape/router/utils" +) + +var ( + timeType = utils.GetType[time.Time]() + uuidType = utils.GetType[uuid.UUID]() + sliceOfBytesType = utils.GetType[[]byte]() + + isString = kindIs(reflect.String) + isUUIDCompatible = anyOf(isString, convertibleTo(uuidType)) + isSliceOfBytes = anyOf(isString, typeIs(sliceOfBytesType)) + isTimeCompatible = anyOf(isString, convertibleTo(timeType)) + isSerializedFromString = anyOf(isString, isUUIDCompatible, isTimeCompatible, isSliceOfBytes) + + isTimeFormat = schemaFormatIs(dateTimeFormat, timeFormat) + isSchemaStringFormat = schemaFormatIs(uuidFormat, byteFormat, dateTimeFormat, timeFormat, dateFormat, durationFormat, + emailFormat, idnEmailFormat, hostnameFormat, idnHostnameFormat, ipv4Format, ipv6Format, uriFormat, + uriReferenceFormat, iriFormat, iriReferenceFormat, uriTemplateFormat, jsonPointerFormat, + relativeJsonPointerFormat, regexFormat, passwordFormat) + + isSerializedFromObject = allOf(kindIs(reflect.Struct, reflect.Map), not(isTimeCompatible)) + + isSchemaTypeStringOrEmpty = schemaTypeIs(openapi3.TypeString, "") + isSchemaTypeBooleanOrEmpty = schemaTypeIs(openapi3.TypeBoolean, "") + isSchemaTypeObjectOrEmpty = schemaTypeIs(openapi3.TypeObject, "") + isSchemaTypeArrayOrEmpty = schemaTypeIs(openapi3.TypeArray, "") + isSchemaTypeNumberOrEmpty = schemaTypeIs(openapi3.TypeNumber, "") + + isBoolType = kindIs(reflect.Bool) + isInt32 = kindIs(reflect.Int32) + isInt64 = kindIs(reflect.Int64) + isFloat32 = kindIs(reflect.Float32) + isFloat64 = kindIs(reflect.Float64) + isNumericType = kindIs(reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64) + + isArrayGoType = allOf(kindIs(reflect.Array, reflect.Slice), not(isUUIDCompatible)) +) + +func isAny(t reflect.Type) bool { + return t.Kind() == reflect.Interface && t.NumMethod() == 0 +} + +type assertion[T any] func(t T) bool +type typeAssertion = assertion[reflect.Type] +type schemaAssertion = assertion[openapi3.Schema] + +func anyOf[A assertion[T], T any](assertions ...A) A { + return func(t T) bool { + for _, assert := range assertions { + if assert(t) { + return true + } + } + return false + } +} + +func allOf[T any, A assertion[T]](assertions ...A) A { + return func(t T) bool { + for _, assert := range assertions { + if !assert(t) { + return false + } + } + return true + } +} + +func not[T any, A assertion[T]](assertion A) A { + return func(t T) bool { + return !assertion(t) + } +} + +func schemaTypeIs(types ...string) schemaAssertion { + set := utils.NewSet(types...) + return func(s openapi3.Schema) bool { + return set.Has(s.Type) + } +} + +func schemaFormatIs(types ...string) schemaAssertion { + set := utils.NewSet(types...) + return func(s openapi3.Schema) bool { + return set.Has(s.Format) + } +} + +func kindIs(kinds ...reflect.Kind) typeAssertion { + set := utils.NewSet(kinds...) + return handleMultiType(func(t reflect.Type) bool { + return set.Has(t.Kind()) + }) +} + +func typeIs(types ...reflect.Type) typeAssertion { + set := utils.NewSet(types...) + return handleMultiType(func(t reflect.Type) bool { + return set.Has(t) + }) +} + +func convertibleTo(targets ...reflect.Type) typeAssertion { + return handleMultiType(func(t reflect.Type) bool { + for _, target := range targets { + if target.ConvertibleTo(t) { + return true + } + } + return false + }) +} + +func handleMultiType(assertion typeAssertion) typeAssertion { + return func(t reflect.Type) bool { + if !utils.IsMultiType(t) { + if t.Kind() == reflect.Pointer { + return assertion(t.Elem()) + } else { + return assertion(t) + } + } + + types, err := utils.ExtractMultiTypeTypes(t) + if err != nil { + return false + } + + for _, mtType := range types { + if assertion(mtType.Elem()) { + return true + } + } + + return false + } +} diff --git a/router/schema_validator/boolean_schema_validator.go b/router/schema_validator/boolean_schema_validator.go index 8103634..85b0a33 100644 --- a/router/schema_validator/boolean_schema_validator.go +++ b/router/schema_validator/boolean_schema_validator.go @@ -1,16 +1,12 @@ package schema_validator -import ( - "fmt" - "reflect" -) +import "github.com/getkin/kin-openapi/openapi3" -func (c typeSchemaValidatorContext) validateBooleanSchema() error { - if c.schema.Type != booleanSchemaType { - return nil - } - if c.goType.Kind() != reflect.Bool { - return fmt.Errorf(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) +func (c typeSchemaValidatorContext) validateBooleanSchema() { + + isTypeBool := isBoolType(c.goType) + if (isTypeBool && !isSchemaTypeBooleanOrEmpty(c.schema)) || + (c.schema.Type == openapi3.TypeBoolean && !isTypeBool) { + c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) } - return nil } diff --git a/router/schema_validator/boolean_schema_validator_test.go b/router/schema_validator/boolean_schema_validator_test.go index 5c2a145..b1dbccf 100644 --- a/router/schema_validator/boolean_schema_validator_test.go +++ b/router/schema_validator/boolean_schema_validator_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" "github.com/piiano/cellotape/router/utils" ) @@ -12,9 +13,8 @@ import ( func TestBooleanSchemaValidatorPassForBoolType(t *testing.T) { booleanSchema := openapi3.NewBoolSchema() validator := schemaValidator(*booleanSchema) - if err := validator.WithType(boolType).Validate(); err != nil { - expectTypeToBeCompatible(t, validator, boolType, "expect boolean schema to be compatible with %s type", boolType) - } + err := validator.WithType(boolType).Validate() + require.NoErrorf(t, err, "expect boolean schema to be compatible with %s type", boolType) } // according to the spec the boolean validation properties should apply only when the type is set to boolean @@ -23,7 +23,7 @@ func TestBoolSchemaValidatorWithUntypedSchema(t *testing.T) { validator := schemaValidator(*untypedSchema) for _, validType := range types { t.Run(validType.String(), func(t *testing.T) { - if err := validator.WithType(validType).validateBooleanSchema(); err != nil { + if err := validator.WithType(validType).Validate(); err != nil { t.Errorf("expect untyped schema to be compatible with %s type", validType) } }) diff --git a/router/schema_validator/error_messages.go b/router/schema_validator/error_messages.go index 418b705..c2e365c 100644 --- a/router/schema_validator/error_messages.go +++ b/router/schema_validator/error_messages.go @@ -10,28 +10,15 @@ import ( func schemaAllOfPropertyIncompatibleWithType(invalidOptions int, options int, goType reflect.Type) string { return fmt.Sprintf("%d/%d schemas defined in allOf are incompatible with type %s", invalidOptions, options, goType) } -func schemaAnyOfPropertyIncompatibleWithType(options int, goType reflect.Type) string { - subject, be := "the schema", "is" - if options > 1 { - subject, be = fmt.Sprintf("all %d schemas", options), "are" - } - return fmt.Sprintf("%s defined in schema anyOf property %s incompatible with type %s", subject, be, goType) -} func schemaTypeWithFormatIsIncompatibleWithType(schema openapi3.Schema, goType reflect.Type) string { return fmt.Sprintf("%s schema with %s format is incompatible with type %s", schema.Type, schema.Format, goType) } + func schemaTypeIsIncompatibleWithType(schema openapi3.Schema, goType reflect.Type) string { return fmt.Sprintf("%s schema is incompatible with type %s", schema.Type, goType) } -func formatMustHaveNoError(err error, schemaType string, goType reflect.Type) error { - if err != nil { - return fmt.Errorf("fail validating %s schema with type %s. %s", schemaType, goType, err.Error()) - } - return nil -} - func schemaPropertyIsNotMappedToFieldInType(name string, fieldType reflect.Type) string { return fmt.Sprintf("property %q is not mapped to a field in type %s", name, fieldType) } diff --git a/router/schema_validator/integer_schema_validator.go b/router/schema_validator/integer_schema_validator.go index 05dec90..1c9c7c8 100644 --- a/router/schema_validator/integer_schema_validator.go +++ b/router/schema_validator/integer_schema_validator.go @@ -2,29 +2,31 @@ package schema_validator import ( "reflect" + + "github.com/getkin/kin-openapi/openapi3" ) -func (c typeSchemaValidatorContext) validateIntegerSchema() error { - l := c.newLogger() - if c.schema.Type != integerSchemaType { - return nil +func (c typeSchemaValidatorContext) validateIntegerSchema() { + + if c.schema.Type != openapi3.TypeInteger { + return } switch c.goType.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: default: - l.Logf(c.level, schemaTypeIsIncompatibleWithType(c.schema, c.goType)) + c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) } switch c.schema.Format { case int32Format: if c.goType.Kind() != reflect.Int32 { - l.Logf(c.level, schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) + c.err(schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) } case int64Format: if c.goType.Kind() != reflect.Int64 { - l.Logf(c.level, schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) + c.err(schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) } } + // TODO: check type compatability with Max, ExclusiveMax, Min, and ExclusiveMin - return formatMustHaveNoError(l.MustHaveNoErrors(), c.schema.Type, c.goType) } diff --git a/router/schema_validator/integer_schema_validator_test.go b/router/schema_validator/integer_schema_validator_test.go index c4209bc..1d2eef2 100644 --- a/router/schema_validator/integer_schema_validator_test.go +++ b/router/schema_validator/integer_schema_validator_test.go @@ -27,7 +27,7 @@ func TestIntegerSchemaValidatorWithUntypedSchema(t *testing.T) { validator := schemaValidator(*untypedSchemaWithInt64Format) for _, validType := range types { t.Run(validType.String(), func(t *testing.T) { - if err := validator.WithType(validType).validateIntegerSchema(); err != nil { + if err := validator.WithType(validType).Validate(); err != nil { t.Errorf("expect untyped schema to be compatible with %s type", validType) } }) diff --git a/router/schema_validator/multi_schemas_validator.go b/router/schema_validator/multi_schemas_validator.go index 969df63..5c648a5 100644 --- a/router/schema_validator/multi_schemas_validator.go +++ b/router/schema_validator/multi_schemas_validator.go @@ -1,27 +1,58 @@ package schema_validator -import "github.com/piiano/cellotape/router/utils" +import ( + "reflect" -type schemaValidation struct { - context TypeSchemaValidator - originalIndex int - logger utils.InMemoryLogger -} + "github.com/getkin/kin-openapi/openapi3" + + "github.com/piiano/cellotape/router/utils" +) + +func (c typeSchemaValidatorContext) matchAllSchemaValidator(name string, schemas openapi3.SchemaRefs) { + if schemas == nil { + return + } -func validateMultipleSchemas(cs ...TypeSchemaValidator) ([]schemaValidation, []schemaValidation) { - pass := make([]schemaValidation, 0) - failed := make([]schemaValidation, 0) - for i, c := range cs { - logger := utils.NewInMemoryLoggerWithLevel(c.logLevel()) - c.WithLogger(logger) - err := c.Validate() - validation := schemaValidation{context: c, logger: logger, originalIndex: i} - if err == nil { - pass = append(pass, validation) + types := []reflect.Type{c.goType} + + if utils.IsMultiType(c.goType) { + if multiTypeTypes, err := utils.ExtractMultiTypeTypes(c.goType); err != nil { + c.err(err.Error()) + } else { + types = multiTypeTypes } - if err != nil { - failed = append(pass, validation) + } + + usedTypes := utils.NewSet[string]() + +schemas: + for index, schema := range schemas { + schemaErrors := make([]string, 0) + + for _, multiTypeType := range types { + typeValidator := typeSchemaValidatorContext{ + errors: new([]string), + schema: *schema.Value, + goType: multiTypeType, + } + + if err := typeValidator.Validate(); err == nil { + usedTypes.Add(multiTypeType.String()) + continue schemas + } + + schemaErrors = append(schemaErrors, *typeValidator.errors...) + } + + c.err("%s schema at index %d didn't match type of %q", name, index, c.goType) + *c.errors = append(*c.errors, schemaErrors...) + } + + if utils.IsMultiType(c.goType) { + for _, multiTypeType := range types { + if !usedTypes.Has(multiTypeType.String()) { + c.err("non of %s schemas match type %q of %q", name, multiTypeType, c.goType) + } } } - return pass, failed } diff --git a/router/schema_validator/multi_schemas_validator_test.go b/router/schema_validator/multi_schemas_validator_test.go new file mode 100644 index 0000000..c9957c3 --- /dev/null +++ b/router/schema_validator/multi_schemas_validator_test.go @@ -0,0 +1,244 @@ +package schema_validator + +import ( + "reflect" + "testing" + "time" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/piiano/cellotape/router/utils" +) + +type composer func(...*openapi3.Schema) *openapi3.Schema + +type multiSchemaCase struct { + name string + composer +} + +var multiSchemaCases = []multiSchemaCase{ + {name: "oneOf", composer: openapi3.NewOneOfSchema}, + {name: "anyOf", composer: openapi3.NewAnyOfSchema}, +} + +func TestMultiSchemaValidator(t *testing.T) { + for _, schemaCase := range multiSchemaCases { + t.Run(schemaCase.name, func(t *testing.T) { + testCases := []struct { + name string + goType reflect.Type + schema *openapi3.Schema + errAssertion func(require.TestingT, error, ...any) + }{ + { + name: "multiple different types", + goType: utils.GetType[any](), + schema: schemaCase.composer( + openapi3.NewBoolSchema(), + openapi3.NewStringSchema(), + openapi3.NewInt64Schema(), + ), + errAssertion: require.NoError, + }, + { + name: "multiple different types", + goType: reflect.TypeOf(&utils.MultiType[struct { + A *bool + B *string + C *int64 + }]{}), + schema: schemaCase.composer( + openapi3.NewBoolSchema(), + openapi3.NewStringSchema(), + openapi3.NewInt64Schema(), + ), + errAssertion: require.NoError, + }, + { + name: "simple go type for schema with multiple different types", + goType: boolType, + schema: schemaCase.composer( + openapi3.NewBoolSchema(), + openapi3.NewStringSchema(), + openapi3.NewInt64Schema(), + ), + errAssertion: require.Error, + }, + { + name: "missing go type for schema with multiple different types", + goType: reflect.TypeOf(&utils.MultiType[struct { + A *bool + B *string + }]{}), + schema: schemaCase.composer( + openapi3.NewBoolSchema(), + openapi3.NewStringSchema(), + openapi3.NewInt64Schema(), + ), + errAssertion: require.Error, + }, + { + name: "missing schema type for multi type go type", + goType: reflect.TypeOf(&utils.MultiType[struct { + A *bool + B *string + C *int64 + }]{}), + schema: schemaCase.composer( + openapi3.NewBoolSchema(), + openapi3.NewStringSchema(), + ), + errAssertion: require.Error, + }, + { + name: "insignificant of types and schema order", + goType: reflect.TypeOf(&utils.MultiType[struct { + B *string + A *bool + }]{}), + schema: schemaCase.composer( + openapi3.NewBoolSchema(), + openapi3.NewStringSchema(), + ), + errAssertion: require.NoError, + }, + { + name: "simple single type with different validations", + goType: intType, + schema: func() *openapi3.Schema { + schema := schemaCase.composer( + &openapi3.Schema{MultipleOf: utils.Ptr(3.0)}, + &openapi3.Schema{MultipleOf: utils.Ptr(5.0)}, + ) + schema.Type = openapi3.TypeNumber + return schema + }(), + errAssertion: require.NoError, + }, + { + name: "simple uuid or time", + goType: reflect.TypeOf(&utils.MultiType[struct { + A *uuid.UUID + B *time.Time + }]{}), + schema: func() *openapi3.Schema { + schema := schemaCase.composer( + &openapi3.Schema{Format: "uuid"}, + &openapi3.Schema{Format: "date-time"}, + ) + schema.Type = openapi3.TypeString + return schema + }(), + errAssertion: require.NoError, + }, + { + name: "integer formats", + goType: int64Type, + schema: func() *openapi3.Schema { + schema := schemaCase.composer( + &openapi3.Schema{Format: int64Format}, + &openapi3.Schema{Format: int32Format}, + ) + schema.Type = openapi3.TypeInteger + return schema + }(), + errAssertion: require.NoError, + }, + { + name: "one struct or another", + goType: reflect.TypeOf(&utils.MultiType[struct { + A *struct{ A string } + B *struct{ B string } + }]{}), + schema: func() *openapi3.Schema { + schema := schemaCase.composer( + &openapi3.Schema{Properties: openapi3.Schemas{"A": openapi3.NewStringSchema().NewRef()}}, + &openapi3.Schema{Properties: openapi3.Schemas{"B": openapi3.NewStringSchema().NewRef()}}, + ) + schema.Type = openapi3.TypeObject + return schema + }(), + errAssertion: require.NoError, + }, + { + name: "one struct or another with error", + goType: reflect.TypeOf(&utils.MultiType[struct { + A *struct{ A string } + B *bool + }]{}), + schema: func() *openapi3.Schema { + schema := schemaCase.composer( + &openapi3.Schema{Properties: openapi3.Schemas{"A": openapi3.NewStringSchema().NewRef()}}, + &openapi3.Schema{Properties: openapi3.Schemas{"B": openapi3.NewStringSchema().NewRef()}}, + ) + schema.Type = openapi3.TypeObject + return schema + }(), + errAssertion: require.Error, + }, + { + name: "one array or another", + goType: reflect.TypeOf(&utils.MultiType[struct { + A *[]string + B *[]bool + }]{}), + schema: func() *openapi3.Schema { + schema := schemaCase.composer( + &openapi3.Schema{Items: openapi3.NewStringSchema().NewRef()}, + &openapi3.Schema{Items: openapi3.NewBoolSchema().NewRef()}, + ) + schema.Type = openapi3.TypeArray + return schema + }(), + errAssertion: require.NoError, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + err := schemaValidator(*testCase.schema).WithType(testCase.goType).Validate() + testCase.errAssertion(t, err) + }) + } + }) + } +} + +func TestSchemaMultiSchemaValidatorFailOnNoMatchedType(t *testing.T) { + for _, schemaCase := range multiSchemaCases { + t.Run(schemaCase.name, func(t *testing.T) { + schema := schemaCase.composer( + openapi3.NewBoolSchema(), + openapi3.NewStringSchema(), + openapi3.NewInt64Schema(), + ) + validator := schemaValidator(*schema) + + errTemplate := "expect schema with %s property to be incompatible with %s type" + for _, invalidType := range types { + t.Run(invalidType.String(), func(t *testing.T) { + expectTypeToBeIncompatible(t, validator, invalidType, errTemplate, schemaCase.name, invalidType) + }) + } + }) + } +} + +func TestCorruptedMultiType(t *testing.T) { + testType := utils.GetType[*utils.MultiType[bool]]() + + isMultiType := utils.IsMultiType(testType) + require.True(t, isMultiType) + + _, err := utils.ExtractMultiTypeTypes(testType) + require.Error(t, err) + + err = schemaValidator(*openapi3.NewOneOfSchema( + openapi3.NewStringSchema(), + openapi3.NewBoolSchema(), + )).WithType(testType).Validate() + require.Error(t, err, "expect schema to be incompatible with invalid MultiType %s", testType) +} diff --git a/router/schema_validator/number_schema_validator.go b/router/schema_validator/number_schema_validator.go index 804e53a..f4d8504 100644 --- a/router/schema_validator/number_schema_validator.go +++ b/router/schema_validator/number_schema_validator.go @@ -1,31 +1,28 @@ package schema_validator import ( - "reflect" + "github.com/getkin/kin-openapi/openapi3" ) -func (c typeSchemaValidatorContext) validateNumberSchema() error { - l := c.newLogger() - if c.schema.Type != numberSchemaType { - return nil - } - switch c.goType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, - reflect.Float32, reflect.Float64: - default: - l.Logf(c.level, schemaTypeIsIncompatibleWithType(c.schema, c.goType)) - } - switch c.schema.Format { - case floatFormat: - if c.goType.Kind() != reflect.Float32 { - l.Logf(c.level, schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) - } - case doubleFormat: - if c.goType.Kind() != reflect.Float64 { - l.Logf(c.level, schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) +func (c typeSchemaValidatorContext) validateNumberSchema() { + isGoTypeNumeric := isNumericType(c.goType) + + if !isGoTypeNumeric { + if c.schema.Type != openapi3.TypeNumber { + return } + + // schema type is numeric and go type is not numeric + c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) + return + } + + // schema type is numeric and go type is not numeric + if (c.schema.Format == floatFormat && !isFloat32(c.goType)) || + (c.schema.Format == doubleFormat && !isFloat64(c.goType)) { + c.err(schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) + return } + // TODO: check type compatability with Max, ExclusiveMax, Min, and ExclusiveMin - return formatMustHaveNoError(l.MustHaveNoErrors(), c.schema.Type, c.goType) } diff --git a/router/schema_validator/number_schema_validator_test.go b/router/schema_validator/number_schema_validator_test.go index 5640b28..2a6c8ea 100644 --- a/router/schema_validator/number_schema_validator_test.go +++ b/router/schema_validator/number_schema_validator_test.go @@ -6,13 +6,14 @@ import ( "testing" "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" "github.com/piiano/cellotape/router/utils" ) func TestNumberSchemaValidatorPassForIntType(t *testing.T) { numberSchema := openapi3.NewSchema() - numberSchema.Type = numberSchemaType + numberSchema.Type = openapi3.TypeNumber validator := schemaValidator(*numberSchema) errTemplate := "expect number schema to be compatible with %s type" for _, numericType := range numericTypes { @@ -25,11 +26,15 @@ func TestNumberSchemaValidatorPassForIntType(t *testing.T) { // according to the spec the number validation properties should apply only when the type is set to number func TestNumberSchemaValidatorWithUntypedSchema(t *testing.T) { untypedSchemaWithDoubleFormat := openapi3.NewSchema().WithFormat(doubleFormat) - validator := schemaValidator(*untypedSchemaWithDoubleFormat) - for _, validType := range types { - t.Run(validType.String(), func(t *testing.T) { - if err := validator.WithType(validType).validateNumberSchema(); err != nil { - t.Errorf("expect untyped schema to be compatible with %s type", validType) + for _, goType := range types { + valid := isFloat64(goType) || !isNumericType(goType) + + t.Run(goType.String(), func(t *testing.T) { + err := schemaValidator(*untypedSchemaWithDoubleFormat).WithType(goType).Validate() + if valid { + require.NoErrorf(t, err, "expect untyped schema to be compatible with %s type", goType) + } else { + require.Errorf(t, err, "expect untyped schema to be incompatible with %s type", goType) } }) } @@ -37,7 +42,7 @@ func TestNumberSchemaValidatorWithUntypedSchema(t *testing.T) { func TestNumberSchemaValidatorFailOnWrongType(t *testing.T) { numberSchema := openapi3.NewSchema() - numberSchema.Type = numberSchemaType + numberSchema.Type = openapi3.TypeNumber validator := schemaValidator(*numberSchema) errTemplate := "expect number schema to be incompatible with %s type" // filter all numeric types from all defined test types @@ -56,7 +61,7 @@ func TestNumberSchemaValidatorFailOnWrongType(t *testing.T) { func TestFloatFormatSchemaValidatorPassForFloat32Type(t *testing.T) { floatSchema := openapi3.NewSchema() - floatSchema.Type = numberSchemaType + floatSchema.Type = openapi3.TypeNumber floatSchema.Format = floatFormat validator := schemaValidator(*floatSchema) errTemplate := "expect number schema with float format to be compatible with %s type" @@ -65,7 +70,7 @@ func TestFloatFormatSchemaValidatorPassForFloat32Type(t *testing.T) { func TestFloat32FormatSchemaValidatorFailOnWrongType(t *testing.T) { floatSchema := openapi3.NewSchema() - floatSchema.Type = numberSchemaType + floatSchema.Type = openapi3.TypeNumber floatSchema.Format = floatFormat validator := schemaValidator(*floatSchema) errTemplate := "expect number schema with float format to be incompatible with %s type" @@ -82,7 +87,7 @@ func TestFloat32FormatSchemaValidatorFailOnWrongType(t *testing.T) { func TestDoubleFormatSchemaValidatorPassForFloat32Type(t *testing.T) { doubleSchema := openapi3.NewSchema() - doubleSchema.Type = numberSchemaType + doubleSchema.Type = openapi3.TypeNumber doubleSchema.Format = doubleFormat validator := schemaValidator(*doubleSchema) errTemplate := "expect number schema with double format to be compatible with %s type" @@ -91,7 +96,7 @@ func TestDoubleFormatSchemaValidatorPassForFloat32Type(t *testing.T) { func TestDoubleFormatSchemaValidatorFailOnWrongType(t *testing.T) { doubleSchema := openapi3.NewSchema() - doubleSchema.Type = numberSchemaType + doubleSchema.Type = openapi3.TypeNumber doubleSchema.Format = doubleFormat validator := schemaValidator(*doubleSchema) errTemplate := "expect number schema with double format to be incompatible with %s type" diff --git a/router/schema_validator/object_schema_validator.go b/router/schema_validator/object_schema_validator.go index bf042e7..431f0aa 100644 --- a/router/schema_validator/object_schema_validator.go +++ b/router/schema_validator/object_schema_validator.go @@ -14,84 +14,96 @@ import ( var textMarshallerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() -func (c typeSchemaValidatorContext) validateObjectSchema() error { +func (c typeSchemaValidatorContext) validateObjectSchema() { // TODO: validate required properties, nullable, additionalProperties, etc. - l := c.newLogger() - if c.schema.Type != objectSchemaType { - return nil + serializedFromObject := isSerializedFromObject(c.goType) + + if !serializedFromObject { + if c.schema.Type == openapi3.TypeObject { + c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) + } + return } + + if !isSchemaTypeObjectOrEmpty(c.schema) { + c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) + } + + handleMultiType(func(t reflect.Type) bool { + if t.Kind() == reflect.Struct { + return c.assertStruct(t) + } + + if t.Kind() == reflect.Map { + return c.assertMap(t) + } + + return false + })(c.goType) +} + +func (c typeSchemaValidatorContext) assertStruct(t reflect.Type) bool { properties := c.schema.Properties if properties == nil { properties = make(map[string]*openapi3.SchemaRef, 0) } - switch c.goType.Kind() { - case reflect.Struct: - if properties != nil { - // TODO: add support for receiving in options how to extract type keys (to support schema for non-json serializers) - fields := structJsonFields(c.goType) - for name, field := range fields { - property, ok := properties[name] - if !ok { - if c.schema.AdditionalProperties == nil && - (c.schema.AdditionalPropertiesAllowed == nil || *c.schema.AdditionalPropertiesAllowed) { - continue - } - - if c.schema.AdditionalPropertiesAllowed != nil && !*c.schema.AdditionalPropertiesAllowed { - l.Logf(c.level, fmt.Sprintf("field %q (%q) with type %s not found in object schema properties", field.Name, name, field.Type)) - continue - } - if c.schema.AdditionalProperties != nil && c.schema.AdditionalProperties.Value != nil { - if err := c.WithSchema(*c.schema.AdditionalProperties.Value).WithType(field.Type).Validate(); err != nil { - l.Logf(c.level, fmt.Sprintf("field %q (%q) with type %s not found in object schema properties nor additonal properties", field.Name, name, field.Type)) - } - } - continue - } - if err := c.WithType(field.Type).WithSchema(*property.Value).Validate(); err != nil { - l.Logf(c.level, schemaPropertyIsIncompatibleWithFieldType(name, field.Name, field.Type)) - } - } - for name, property := range properties { - field, ok := fields[name] - if !ok { - l.Logf(c.level, schemaPropertyIsNotMappedToFieldInType(name, c.goType)) - continue - } - if err := c.WithType(field.Type).WithSchema(*property.Value).Validate(); err != nil { - l.Logf(c.level, schemaPropertyIsIncompatibleWithFieldType(name, field.Name, field.Type)) - } + // TODO: add support for receiving in options how to extract type keys (to support schema for non-json serializers) + fields := structJsonFields(t) + + validatedFields := utils.NewSet[string]() + for name, field := range fields { + + validatedFields.Add(name) + if property, ok := properties[name]; !ok { + if additionalProperties := additionalPropertiesSchema(c.schema); additionalProperties == nil { + c.err(fmt.Sprintf("field %q (%q) with type %s not found in object schema properties", field.Name, name, field.Type)) + } else if err := c.WithSchema(*additionalProperties).WithType(field.Type).Validate(); err != nil { + c.err(fmt.Sprintf("field %q (%q) with type %s not found in object schema properties nor additonal properties", field.Name, name, field.Type)) } + } else if err := c.WithType(field.Type).WithSchema(*property.Value).Validate(); err != nil { + c.err(schemaPropertyIsIncompatibleWithFieldType(name, field.Name, field.Type)) } - case reflect.Map: - keyType := c.goType.Key() - mapValueType := c.goType.Elem() - switch keyType.Kind() { - case reflect.String, - reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - default: - if !keyType.Implements(textMarshallerType) { - l.Logf(c.level, "object schema with map type must have a string compatible type. %s key is not string compatible", keyType) - } + } + for name := range properties { + if !validatedFields.Has(name) { + c.err(schemaPropertyIsNotMappedToFieldInType(name, t)) } - if properties != nil { - for name, property := range properties { - // check if property name is compatible with the map key type - keyName, _ := json.Marshal(name) - if err := json.Unmarshal(keyName, reflect.New(keyType).Interface()); err != nil { - l.Logf(c.level, "schema property name %q is incompatible with map key type %s", name, keyType) - } - // check if property schema is compatible with the map value type - if err := c.WithType(mapValueType).WithSchema(*property.Value).Validate(); err != nil { - l.Logf(c.level, "schema property %q is incompatible with map value type %s", name, mapValueType) - } + } + + return len(*c.errors) == 0 +} +func (c typeSchemaValidatorContext) assertMap(t reflect.Type) bool { + keyType := t.Key() + mapValueType := t.Elem() + + // From the internal implementation of json.Marshal & json Unmarshal: + // Map key must either have string kind, have an integer kind, + // or be an encoding.TextUnmarshaler. + switch keyType.Kind() { + case reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + default: + if !keyType.Implements(textMarshallerType) { + c.err("object schema with map type must have a string compatible type. %s key is not string compatible", keyType) + } + } + if c.schema.Properties != nil { + + for name, property := range c.schema.Properties { + // check if property name is compatible with the map key type + keyName, _ := json.Marshal(name) + if err := json.Unmarshal(keyName, reflect.New(keyType).Interface()); err != nil { + c.err("schema property name %q is incompatible with map key type %s", name, keyType) + } + // check if property schema is compatible with the map value type + if err := c.WithType(mapValueType).WithSchema(*property.Value).Validate(); err != nil { + c.err("schema property %q is incompatible with map value type %s", name, mapValueType) } } - default: - l.Logf(c.level, "object schema must be a struct type or a map. %s type is incompatible", c.goType) } - return formatMustHaveNoError(l.MustHaveNoErrors(), c.schema.Type, c.goType) + + return len(*c.errors) == 0 } // structJsonFields Extract the struct fields that are serializable as JSON corresponding to their JSON key @@ -140,3 +152,18 @@ func structJsonFields(structType reflect.Type) map[string]reflect.StructField { } return fields } + +func additionalPropertiesSchema(schema openapi3.Schema) *openapi3.Schema { + // if additional properties schema is defined explicitly return it + if schema.AdditionalProperties != nil { + return schema.AdditionalProperties.Value + } + + // if additional properties is empty (tru by default) or set explicitly to true return an empty schema (schema for any type) + if schema.AdditionalPropertiesAllowed == nil || *schema.AdditionalPropertiesAllowed { + return openapi3.NewSchema() + } + + // return nil if additional properties are not allowed + return nil +} diff --git a/router/schema_validator/object_schema_validator_test.go b/router/schema_validator/object_schema_validator_test.go index 7f8d30a..a149baf 100644 --- a/router/schema_validator/object_schema_validator_test.go +++ b/router/schema_validator/object_schema_validator_test.go @@ -5,17 +5,54 @@ import ( "testing" "github.com/getkin/kin-openapi/openapi3" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/piiano/cellotape/router/utils" ) +// according to the spec the object validation properties should apply only when the type is set to object +func TestObjectSchemaValidatorWithImplementsTextMarshallerKey(t *testing.T) { + untypedSchemaWithProperty := openapi3.NewObjectSchema() + + testTypes := []reflect.Type{ + utils.GetType[map[uuid.UUID]bool](), + utils.GetType[map[uuid.Time]bool](), + } + + for _, testType := range testTypes { + t.Run(testType.String(), func(t *testing.T) { + err := schemaValidator(*untypedSchemaWithProperty).WithType(testType).Validate() + require.NoErrorf(t, err, "expect schema to be compatible with %s type", testType) + }) + } +} + +// according to the spec the object validation properties should apply only when the type is set to object +func TestObjectSchemaValidatorWithNonSerializableMapKey(t *testing.T) { + untypedSchemaWithProperty := openapi3.NewObjectSchema() + + // json Marshaller and Unmarshaller don't support keys from type struct{} + testType := utils.GetType[map[struct{}]bool]() + + err := schemaValidator(*untypedSchemaWithProperty).WithType(testType).Validate() + require.Errorf(t, err, "expect schema to be incompatible with %s type", testType) +} + // according to the spec the object validation properties should apply only when the type is set to object func TestObjectSchemaValidatorWithUntypedSchema(t *testing.T) { untypedSchemaWithProperty := openapi3.NewSchema().WithProperty("name", openapi3.NewStringSchema()) - validator := schemaValidator(*untypedSchemaWithProperty) - for _, validType := range types { - t.Run(validType.String(), func(t *testing.T) { - if err := validator.WithType(validType).validateObjectSchema(); err != nil { - t.Errorf("expect untyped schema to be compatible with %s type", validType) - } + + for _, testType := range types { + if testType.Kind() == reflect.Struct || + testType.Kind() == reflect.Map || + (testType.Kind() == reflect.Pointer && (testType.Elem().Kind() == reflect.Struct || + testType.Elem().Kind() == reflect.Map)) { + continue + } + t.Run(testType.String(), func(t *testing.T) { + err := schemaValidator(*untypedSchemaWithProperty).WithType(testType).Validate() + require.NoErrorf(t, err, "expect untyped schema to be compatible with %s type", testType) }) } } @@ -55,26 +92,26 @@ func TestObjectSchemaValidatorWithSimpleStructAdditionalProperties(t *testing.T) simpleStructSchema := openapi3.NewObjectSchema(). WithProperty("Field1", openapi3.NewStringSchema()). WithProperty("Field2", openapi3.NewIntegerSchema()) - validator := schemaValidator(*simpleStructSchema) - simpleStructType := reflect.TypeOf(SimpleStruct{}) + + simpleStructType := utils.GetType[SimpleStruct]() errTemplate := "expect object schema to be %s with %s type" - expectTypeToBeCompatible(t, validator, simpleStructType, errTemplate, "compatible", simpleStructType) + expectTypeToBeCompatible(t, schemaValidator(*simpleStructSchema), simpleStructType, errTemplate, "compatible", simpleStructType) - expectTypeToBeCompatible(t, validator.WithSchema(*simpleStructSchema.WithAnyAdditionalProperties()), + expectTypeToBeCompatible(t, schemaValidator(*simpleStructSchema).WithSchema(*simpleStructSchema.WithAnyAdditionalProperties()), simpleStructType, errTemplate, "compatible", simpleStructType) explicitWithoutAdditionalProperties := *simpleStructSchema - var f = false - explicitWithoutAdditionalProperties.AdditionalPropertiesAllowed = &f - expectTypeToBeIncompatible(t, validator.WithSchema(explicitWithoutAdditionalProperties), + + explicitWithoutAdditionalProperties.AdditionalPropertiesAllowed = utils.Ptr(false) + expectTypeToBeIncompatible(t, schemaValidator(*simpleStructSchema).WithSchema(explicitWithoutAdditionalProperties), simpleStructType, errTemplate, "incompatible", simpleStructType) - expectTypeToBeIncompatible(t, validator.WithSchema(*simpleStructSchema. + expectTypeToBeIncompatible(t, schemaValidator(*simpleStructSchema).WithSchema(*simpleStructSchema. WithAdditionalProperties(openapi3.NewStringSchema())), simpleStructType, errTemplate, "incompatible", simpleStructType) - expectTypeToBeCompatible(t, validator.WithSchema(*simpleStructSchema. + expectTypeToBeCompatible(t, schemaValidator(*simpleStructSchema).WithSchema(*simpleStructSchema. WithAdditionalProperties(openapi3.NewBoolSchema())), simpleStructType, errTemplate, "compatible", simpleStructType) diff --git a/router/schema_validator/schema_allof_validator.go b/router/schema_validator/schema_allof_validator.go index b1c18d6..4a1d744 100644 --- a/router/schema_validator/schema_allof_validator.go +++ b/router/schema_validator/schema_allof_validator.go @@ -1,12 +1,14 @@ package schema_validator -func (c typeSchemaValidatorContext) validateSchemaAllOf() error { +func (c typeSchemaValidatorContext) validateSchemaAllOf() { if c.schema.AllOf == nil { - return nil + return } - l := c.newLogger() + errors := len(*c.errors) for _, option := range c.schema.AllOf { - l.LogIfNotNil(c.level, c.WithSchema(*option.Value).Validate()) + _ = c.WithSchema(*option.Value).Validate() + } + if len(*c.errors) > errors { + c.err(schemaAllOfPropertyIncompatibleWithType(len(*c.errors)-errors, len(c.schema.AllOf), c.goType)) } - return l.MustHaveNoErrorsf(schemaAllOfPropertyIncompatibleWithType(l.Errors(), len(c.schema.AllOf), c.goType)) } diff --git a/router/schema_validator/schema_anyof_validator.go b/router/schema_validator/schema_anyof_validator.go deleted file mode 100644 index bb3f917..0000000 --- a/router/schema_validator/schema_anyof_validator.go +++ /dev/null @@ -1,26 +0,0 @@ -package schema_validator - -import ( - "github.com/getkin/kin-openapi/openapi3" - - "github.com/piiano/cellotape/router/utils" -) - -func (c typeSchemaValidatorContext) validateSchemaAnyOf() error { - if c.schema.AnyOf == nil { - return nil - } - //l := c.newLogger() - l := utils.NewInMemoryLoggerWithLevel(c.level) - pass, failed := validateMultipleSchemas(utils.Map(c.schema.AnyOf, func(t *openapi3.SchemaRef) TypeSchemaValidator { - return c.WithSchema(*t.Value) - })...) - if len(pass) == 0 { - l.Logf(c.level, "schema with anyOf property is incompatible with type %s", c.goType) - for i, check := range failed { - l.Logf(c.level, "anyOf[%d] didn't match type %s", i, c.goType) - l.Log(c.level, check.logger.Printed()) - } - } - return l.MustHaveNoErrorsf(schemaAnyOfPropertyIncompatibleWithType(len(c.schema.AnyOf), c.goType)) -} diff --git a/router/schema_validator/schema_anyof_validator_test.go b/router/schema_validator/schema_anyof_validator_test.go deleted file mode 100644 index ddc3c4d..0000000 --- a/router/schema_validator/schema_anyof_validator_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package schema_validator - -import ( - "reflect" - "testing" - - "github.com/getkin/kin-openapi/openapi3" - - "github.com/piiano/cellotape/router/utils" -) - -func TestSchemaAnyOfValidatorPass(t *testing.T) { - notBooleanSchema := openapi3.NewSchema() - notBooleanSchema.AnyOf = openapi3.SchemaRefs{ - openapi3.NewBoolSchema().NewRef(), - openapi3.NewStringSchema().NewRef(), - openapi3.NewInt64Schema().NewRef(), - } - validator := schemaValidator(*notBooleanSchema) - var validTypes = []reflect.Type{boolType, stringType, int64Type} - errTemplate := "expect schema with anyOf property to be compatible with %s type" - for _, validType := range validTypes { - t.Run(validType.String(), func(t *testing.T) { - expectTypeToBeCompatible(t, validator, validType, errTemplate, validType) - }) - } -} - -func TestSchemaAnyOfValidatorPassOnMoreThanOneMatchedType(t *testing.T) { - notBooleanSchema := openapi3.NewSchema() - numberSchema := openapi3.NewSchema() - numberSchema.Type = numberSchemaType - notBooleanSchema.AnyOf = openapi3.SchemaRefs{ - openapi3.NewBoolSchema().NewRef(), - openapi3.NewStringSchema().NewRef(), - openapi3.NewInt64Schema().NewRef(), - numberSchema.NewRef(), - } - validator := schemaValidator(*notBooleanSchema) - errTemplate := "expect schema with anyOf property to be compatible with %s type" - expectTypeToBeCompatible(t, validator, int64Type, errTemplate, int64Type) -} - -func TestSchemaAnyOfValidatorFailOnNoMatchedType(t *testing.T) { - notBooleanSchema := openapi3.NewSchema() - notBooleanSchema.AnyOf = openapi3.SchemaRefs{ - openapi3.NewBoolSchema().NewRef(), - openapi3.NewStringSchema().NewRef(), - openapi3.NewInt64Schema().NewRef(), - } - validator := schemaValidator(*notBooleanSchema) - invalidTypes := utils.Filter(types, func(t reflect.Type) bool { - return t != boolType && t != stringType && t != int64Type && - t != reflect.PointerTo(boolType) && t != reflect.PointerTo(stringType) && t != reflect.PointerTo(int64Type) - }) - errTemplate := "expect schema with anyOf property to be incompatible with %s type" - for _, invalidType := range invalidTypes { - t.Run(invalidType.String(), func(t *testing.T) { - expectTypeToBeIncompatible(t, validator, invalidType, errTemplate, invalidType) - }) - } -} diff --git a/router/schema_validator/schema_format_validator.go b/router/schema_validator/schema_format_validator.go index f5a2c3c..c4573d6 100644 --- a/router/schema_validator/schema_format_validator.go +++ b/router/schema_validator/schema_format_validator.go @@ -2,77 +2,80 @@ package schema_validator // Formats https://datatracker.ietf.org/doc/html/draft-bhutton-json-schema-validation-00#section-7.3 const ( - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid representation according to the "date-time" production. dateTimeFormat = "date-time" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid representation according to the "full-date" production. dateFormat = "date" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid representation according to the "full-time" production. timeFormat = "time" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid representation according to the "duration" production. durationFormat = "duration" - // use with stringSchemaType + // use with openapi3.TypeString // As defined by the "Mailbox" ABNF rule in RFC 5321, section 4.1.2 [RFC5321]. emailFormat = "email" - // use with stringSchemaType + // use with openapi3.TypeString // As defined by the extended "Mailbox" ABNF rule in RFC 6531, section 3.3 [RFC6531]. idnEmailFormat = "idn-email" // As defined by RFC 1123, section 2.1 [RFC1123], including host names produced using the Punycode algorithm specified in RFC 5891, section 4.4 [RFC5891]. hostnameFormat = "hostname" - // use with stringSchemaType + // use with openapi3.TypeString // As defined by either RFC 1123 as for hostname, or an internationalized hostname as defined by RFC 5890, section 2.3.2.3 [RFC5890]. idnHostnameFormat = "idn-hostname" - // use with stringSchemaType + // use with openapi3.TypeString // An IPv4 address according to the "dotted-quad" ABNF syntax as defined in RFC 2673, section 3.2 [RFC2673]. ipv4Format = "ipv4" - // use with stringSchemaType + // use with openapi3.TypeString // An IPv6 address as defined in RFC 4291, section 2.2 [RFC4291]. ipv6Format = "ipv6" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid URI, according to [RFC3986]. uriFormat = "uri" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid URI Reference (either a URI or a relative-reference), according to [RFC3986]. uriReferenceFormat = "uri-reference" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid IRI, according to [RFC3987]. iriFormat = "iri" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid IRI Reference (either an IRI or a relative-reference), according to [RFC3987]. iriReferenceFormat = "iri-reference" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid string representation of a UUID, according to [RFC4122]. uuidFormat = "uuid" - // use with stringSchemaType + // use with openapi3.TypeString + // A string instance is valid against this attribute if it is a valid base64 string. + byteFormat = "byte" + // use with openapi3.TypeString // This attribute applies to string instances. // A string instance is valid against this attribute if it is a valid URI Template (of any level), according to [RFC6570]. // Note that URI Templates may be used for IRIs; there is no separate IRI Template specification. uriTemplateFormat = "uri-template" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid JSON string representation of a JSON Pointer, according to RFC 6901, section 5 [RFC6901]. jsonPointerFormat = "json-pointer" - // use with stringSchemaType + // use with openapi3.TypeString // A string instance is valid against this attribute if it is a valid Relative JSON Pointer [relative-json-pointer]. relativeJsonPointerFormat = "relative-json-pointer" - // use with stringSchemaType + // use with openapi3.TypeString // This attribute applies to string instances. // A regular expression, which SHOULD be valid according to the ECMA-262 [ecma262] regular expression dialect. // Implementations that validate formats MUST accept at least the subset of ECMA-262 defined in the Regular Expressions (Section 4.3) section of this specification, and SHOULD accept all valid ECMA-262 expressions. regexFormat = "regex" // Additional formats defined by OpenAPI spec https://spec.openapis.org/oas/v3.1.0#data-types - // use with integerSchemaType + // use with openapi3.TypeInteger int32Format = "int32" - // use with integerSchemaType + // use with openapi3.TypeInteger int64Format = "int64" - // use with numberSchemaType + // use with openapi3.TypeNumber floatFormat = "float" - // use with numberSchemaType + // use with openapi3.TypeNumber doubleFormat = "double" - // use with stringSchemaType + // use with openapi3.TypeString passwordFormat = "password" ) diff --git a/router/schema_validator/schema_not_validator.go b/router/schema_validator/schema_not_validator.go index 00bc0b5..0b3fc69 100644 --- a/router/schema_validator/schema_not_validator.go +++ b/router/schema_validator/schema_not_validator.go @@ -1,15 +1,13 @@ package schema_validator -import ( - "fmt" -) - -func (c typeSchemaValidatorContext) validateSchemaNot() error { +func (c typeSchemaValidatorContext) validateSchemaNot() { if c.schema.Not == nil { - return nil + return } + errors := len(*c.errors) if err := c.WithSchema(*c.schema.Not.Value).Validate(); err == nil { - return fmt.Errorf("schema with not property is incompatible with type %s", c.goType) + c.err("schema with not property is incompatible with type %s", c.goType) + return } - return nil + *c.errors = (*c.errors)[:errors] } diff --git a/router/schema_validator/schema_oneof_validator.go b/router/schema_validator/schema_oneof_validator.go deleted file mode 100644 index 1d2bfa7..0000000 --- a/router/schema_validator/schema_oneof_validator.go +++ /dev/null @@ -1,31 +0,0 @@ -package schema_validator - -import ( - "github.com/getkin/kin-openapi/openapi3" - - "github.com/piiano/cellotape/router/utils" -) - -func (c typeSchemaValidatorContext) validateSchemaOneOf() error { - if c.schema.OneOf == nil { - return nil - } - l := c.newLogger() - pass, failed := validateMultipleSchemas(utils.Map(c.schema.OneOf, func(t *openapi3.SchemaRef) TypeSchemaValidator { - return c.WithSchema(*t.Value) - })...) - if len(pass) == 0 { - l.Logf(c.level, "schema with oneOf property has no matches for the type %q", c.goType) - for _, check := range failed { - l.Logf(c.level, "oneOf[%d] didn't match type %q", check.originalIndex, c.goType) - l.Log(c.level, check.logger.Printed()) - } - } - if len(pass) > 1 { - l.Logf(c.level, "schema with oneOf property has more than one match for the type %q", c.goType) - for _, check := range pass { - l.Logf(c.level, "oneOf[%d] matched type %q", check.originalIndex, c.goType) - } - } - return l.MustHaveNoErrorsf("schema with oneOf property is incompatible with type %s", c.goType) -} diff --git a/router/schema_validator/schema_oneof_validator_test.go b/router/schema_validator/schema_oneof_validator_test.go deleted file mode 100644 index f52dc25..0000000 --- a/router/schema_validator/schema_oneof_validator_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package schema_validator - -import ( - "reflect" - "testing" - - "github.com/getkin/kin-openapi/openapi3" - - "github.com/piiano/cellotape/router/utils" -) - -func TestSchemaOneOfValidatorPass(t *testing.T) { - notBooleanSchema := openapi3.NewSchema() - notBooleanSchema.OneOf = openapi3.SchemaRefs{ - openapi3.NewBoolSchema().NewRef(), - openapi3.NewStringSchema().NewRef(), - openapi3.NewInt64Schema().NewRef(), - } - validator := schemaValidator(*notBooleanSchema) - var validTypes = []reflect.Type{boolType, stringType, int64Type} - errTemplate := "expect schema with oneOf property to be compatible with %s type" - for _, validType := range validTypes { - t.Run(validType.String(), func(t *testing.T) { - expectTypeToBeCompatible(t, validator, validType, errTemplate, validType) - }) - } -} - -func TestSchemaOneOfValidatorFailOnMoreThanOneMatchedType(t *testing.T) { - notBooleanSchema := openapi3.NewSchema() - numberSchema := openapi3.NewSchema() - numberSchema.Type = numberSchemaType - notBooleanSchema.OneOf = openapi3.SchemaRefs{ - openapi3.NewBoolSchema().NewRef(), - openapi3.NewStringSchema().NewRef(), - openapi3.NewInt64Schema().NewRef(), - numberSchema.NewRef(), - } - validator := schemaValidator(*notBooleanSchema) - errTemplate := "expect schema with oneOf property to be incompatible with %s type" - expectTypeToBeIncompatible(t, validator, int64Type, errTemplate, int64Type) -} - -func TestSchemaOneOfValidatorFailOnNoMatchedType(t *testing.T) { - notBooleanSchema := openapi3.NewSchema() - notBooleanSchema.OneOf = openapi3.SchemaRefs{ - openapi3.NewBoolSchema().NewRef(), - openapi3.NewStringSchema().NewRef(), - openapi3.NewInt64Schema().NewRef(), - } - validator := schemaValidator(*notBooleanSchema) - invalidTypes := utils.Filter(types, func(t reflect.Type) bool { - return t != boolType && t != stringType && t != int64Type && - t != reflect.PointerTo(boolType) && t != reflect.PointerTo(stringType) && t != reflect.PointerTo(int64Type) - }) - errTemplate := "expect schema with oneOf property to be incompatible with %s type" - for _, invalidType := range invalidTypes { - t.Run(invalidType.String(), func(t *testing.T) { - expectTypeToBeIncompatible(t, validator, invalidType, errTemplate, invalidType) - }) - } -} diff --git a/router/schema_validator/schema_validator.go b/router/schema_validator/schema_validator.go index e3de239..dbb571b 100644 --- a/router/schema_validator/schema_validator.go +++ b/router/schema_validator/schema_validator.go @@ -1,6 +1,8 @@ package schema_validator import ( + "errors" + "fmt" "reflect" "github.com/getkin/kin-openapi/openapi3" @@ -8,20 +10,10 @@ import ( "github.com/piiano/cellotape/router/utils" ) -// schema types allowed by OpenAPI specification. -const ( - objectSchemaType = "object" - arraySchemaType = "array" - stringSchemaType = "string" - booleanSchemaType = "boolean" - numberSchemaType = "number" - integerSchemaType = "integer" -) +var ErrSchemaIncompatibleWithType = errors.New("schema is incompatible with type") // TypeSchemaValidator helps validate reflect.Type and openapi3.Schema compatibility using the validation Options. type TypeSchemaValidator interface { - // WithLogger immutably returns a new TypeSchemaValidator with the specified utils.Logger. - WithLogger(utils.Logger) TypeSchemaValidator // WithType immutably returns a new TypeSchemaValidator with the specified reflect.Type to validate. WithType(reflect.Type) TypeSchemaValidator // WithSchema immutably returns a new TypeSchemaValidator with the specified openapi3.Schema to validate. @@ -32,106 +24,95 @@ type TypeSchemaValidator interface { // Returns error with all compatability errors found or nil if compatible. Validate() error - validateSchemaAllOf() error - validateSchemaOneOf() error - validateSchemaAnyOf() error - validateSchemaNot() error - validateObjectSchema() error - validateArraySchema() error - validateStringSchema() error - validateBooleanSchema() error - validateIntegerSchema() error - validateNumberSchema() error - - newLogger() utils.Logger - logLevel() utils.LogLevel + Errors() []string + + matchAllSchemaValidator(string, openapi3.SchemaRefs) + validateSchemaAllOf() + validateSchemaNot() + validateObjectSchema() + validateArraySchema() + validateStringSchema() + validateBooleanSchema() + validateIntegerSchema() + validateNumberSchema() } // NewEmptyTypeSchemaValidator returns a new TypeSchemaValidator that have no reflect.Type or openapi3.Schema configured yet. -func NewEmptyTypeSchemaValidator(logger utils.Logger) TypeSchemaValidator { +func NewEmptyTypeSchemaValidator() TypeSchemaValidator { return typeSchemaValidatorContext{ - logger: logger, - level: utils.Error, + errors: new([]string), } } // NewTypeSchemaValidator returns a new TypeSchemaValidator that helps validate reflect.Type and openapi3.Schema compatibility using the validation Options. -func NewTypeSchemaValidator(logger utils.Logger, level utils.LogLevel, goType reflect.Type, schema openapi3.Schema) TypeSchemaValidator { +func NewTypeSchemaValidator(goType reflect.Type, schema openapi3.Schema) TypeSchemaValidator { return typeSchemaValidatorContext{ - logger: logger, - level: level, + errors: new([]string), schema: schema, goType: goType, } } -func (c typeSchemaValidatorContext) newLogger() utils.Logger { - return c.logger.NewCounter() -} -func (c typeSchemaValidatorContext) logLevel() utils.LogLevel { - return c.level -} - // typeSchemaValidatorContext an internal struct that implementation TypeSchemaValidator type typeSchemaValidatorContext struct { - logger utils.Logger - level utils.LogLevel + errors *[]string schema openapi3.Schema goType reflect.Type } -func (c typeSchemaValidatorContext) WithLogger(logger utils.Logger) TypeSchemaValidator { - c.logger = logger - return c +func (c typeSchemaValidatorContext) err(format string, args ...any) { + *c.errors = append(*c.errors, fmt.Sprintf(format, args...)) } + func (c typeSchemaValidatorContext) WithType(goType reflect.Type) TypeSchemaValidator { c.goType = goType - c.logger = c.newLogger() return c } func (c typeSchemaValidatorContext) WithSchema(schema openapi3.Schema) TypeSchemaValidator { c.schema = schema - c.logger = c.newLogger() return c } func (c typeSchemaValidatorContext) WithSchemaAndType(schema openapi3.Schema, goType reflect.Type) TypeSchemaValidator { c.schema = schema c.goType = goType - c.logger = c.newLogger() return c } +func (c typeSchemaValidatorContext) Errors() []string { + return *c.errors +} func (c typeSchemaValidatorContext) Validate() error { - if isEmptyInterface(c.goType) { + if isAny(c.goType) { return nil } - if c.goType.Kind() == reflect.Pointer { + if utils.IsMultiType(c.goType) { + if _, err := utils.ExtractMultiTypeTypes(c.goType); err != nil { + c.err(err.Error()) + } + } + if c.goType.Kind() == reflect.Pointer && !utils.IsMultiType(c.goType) { return c.WithType(c.goType.Elem()).Validate() } + // Test global schema validation properties - c.logger.ErrorIfNotNil(c.validateSchemaAllOf()) - c.logger.ErrorIfNotNil(c.validateSchemaOneOf()) - c.logger.ErrorIfNotNil(c.validateSchemaAnyOf()) - c.logger.ErrorIfNotNil(c.validateSchemaNot()) + c.validateSchemaAllOf() + c.validateSchemaNot() + c.matchAllSchemaValidator("oneOf", c.schema.OneOf) + c.matchAllSchemaValidator("anyOf", c.schema.AnyOf) // Test specific schema types validations - switch c.schema.Type { - case objectSchemaType: - c.logger.ErrorIfNotNil(c.validateObjectSchema()) - case arraySchemaType: - c.logger.ErrorIfNotNil(c.validateArraySchema()) - case stringSchemaType: - c.logger.ErrorIfNotNil(c.validateStringSchema()) - case booleanSchemaType: - c.logger.ErrorIfNotNil(c.validateBooleanSchema()) - case numberSchemaType: - c.logger.ErrorIfNotNil(c.validateNumberSchema()) - case integerSchemaType: - c.logger.ErrorIfNotNil(c.validateIntegerSchema()) + c.validateObjectSchema() + c.validateArraySchema() + c.validateStringSchema() + c.validateBooleanSchema() + c.validateNumberSchema() + c.validateIntegerSchema() + + if len(*c.errors) > 0 { + err := fmt.Errorf("%w %s", ErrSchemaIncompatibleWithType, c.goType) + c.err(err.Error()) + return err } - return c.logger.MustHaveNoErrors() -} -func isEmptyInterface(t reflect.Type) bool { - return t.Kind() == reflect.Interface && t.NumMethod() == 0 + return nil } diff --git a/router/schema_validator/schema_validator_test.go b/router/schema_validator/schema_validator_test.go index 605f7b5..02f02bc 100644 --- a/router/schema_validator/schema_validator_test.go +++ b/router/schema_validator/schema_validator_test.go @@ -93,7 +93,14 @@ func arrayTypes(depth int) []reflect.Type { func mapTypes(depth int) []reflect.Type { allTypes := allTypes(depth) comparableTypes := utils.Filter(allTypes, func(t reflect.Type) bool { - return t.Comparable() + switch t.Kind() { + case reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return t.Comparable() + default: + return false + } }) results := make([]reflect.Type, len(allTypes)*len(comparableTypes)) for i, keyType := range comparableTypes { @@ -119,6 +126,7 @@ func expectTypeToBeCompatible(t *testing.T, validator TypeSchemaValidator, testT t.Error(err) } } + func expectTypeToBeIncompatible(t *testing.T, validator TypeSchemaValidator, testType reflect.Type, errTemplate string, args ...any) { if err := validator.WithType(testType).Validate(); err == nil { t.Errorf(errTemplate, args...) @@ -126,9 +134,8 @@ func expectTypeToBeIncompatible(t *testing.T, validator TypeSchemaValidator, tes } func TestSchemaValidatorWithOptions(t *testing.T) { - logger := utils.NewInMemoryLogger() stringSchema := openapi3.NewStringSchema() - validator := NewEmptyTypeSchemaValidator(logger).WithSchema(*stringSchema) + validator := NewEmptyTypeSchemaValidator().WithSchema(*stringSchema) errTemplate := "expect string schema with time format to be %s with %s type" expectTypeToBeCompatible(t, validator, stringType, errTemplate, "compatible", stringType) // omit the string type from all defined test types @@ -143,8 +150,7 @@ func TestSchemaValidatorWithOptions(t *testing.T) { } func emptyValidator() TypeSchemaValidator { - logger := utils.NewInMemoryLogger() - return NewEmptyTypeSchemaValidator(logger) + return NewEmptyTypeSchemaValidator() } func typeValidator(goType reflect.Type) TypeSchemaValidator { return emptyValidator().WithType(goType) diff --git a/router/schema_validator/string_schema_validator.go b/router/schema_validator/string_schema_validator.go index 341c35b..5999305 100644 --- a/router/schema_validator/string_schema_validator.go +++ b/router/schema_validator/string_schema_validator.go @@ -1,40 +1,57 @@ package schema_validator -import ( - "reflect" - "time" +import "github.com/getkin/kin-openapi/openapi3" - "github.com/google/uuid" -) +func (c typeSchemaValidatorContext) validateStringSchema() { + if c.schema.Type == openapi3.TypeString && !isSerializedFromString(c.goType) { + c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) + } -var timeType = reflect.TypeOf(new(time.Time)).Elem() -var uuidType = reflect.TypeOf(new(uuid.UUID)).Elem() + // if schema type is not string and not empty + if !isSchemaTypeStringOrEmpty(c.schema) { + // and go type is a string + if isString(c.goType) { + // can't have a go type string for the remaining schema types: boolean, number, integer, array, object + c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) + } -func (c typeSchemaValidatorContext) validateStringSchema() error { - l := c.newLogger() - if c.schema.Type != stringSchemaType { - return nil + // if schema type is not string and not empty other string validations has no meaning. return early. + return } - switch c.schema.Format { - case "": - if c.goType.Kind() != reflect.String { - l.Logf(c.level, schemaTypeIsIncompatibleWithType(c.schema, c.goType)) - } - case uuidFormat: - if c.goType.Kind() != reflect.String && !uuidType.ConvertibleTo(c.goType) { - l.Logf(c.level, schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) + + // if schema format is empty return early + // pattern, minLength and maxLength have no effect on type validation. + if c.schema.Format == "" { + return + } + + // if schema format is "byte" expect type to be compatible with []byte + if c.schema.Format == byteFormat { + if (c.schema.Type == openapi3.TypeString || isSerializedFromString(c.goType)) && !isSliceOfBytes(c.goType) { + c.err(schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) } - case dateTimeFormat, timeFormat: - if c.goType.Kind() != reflect.String && !timeType.ConvertibleTo(c.goType) { - l.Logf(c.level, schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) + return + } + + // if schema format is "uuid" expect type to be compatible with UUID + if c.schema.Format == uuidFormat { + if (c.schema.Type == openapi3.TypeString || isSerializedFromString(c.goType)) && !isUUIDCompatible(c.goType) { + c.err(schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) } - // TODO: add support for more formats compatible types (dateTimeFormat, dateFormat, durationFormat, etc.) - case dateFormat, durationFormat, emailFormat, idnEmailFormat, hostnameFormat, - idnHostnameFormat, ipv4Format, ipv6Format, uriFormat, uriReferenceFormat, iriFormat, iriReferenceFormat, - uriTemplateFormat, jsonPointerFormat, relativeJsonPointerFormat, regexFormat, passwordFormat: - if c.goType.Kind() != reflect.String { - l.Logf(c.level, schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) + return + } + + // if schema format is "date-time" or "time" expect type to be compatible with Time + if isTimeFormat(c.schema) { + if (c.schema.Type == openapi3.TypeString || isSerializedFromString(c.goType)) && !isTimeCompatible(c.goType) { + c.err(schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) } + return + } + + // if schema format is any other string format expect go type to be string + if isSchemaStringFormat(c.schema) && (c.schema.Type == openapi3.TypeString || isSerializedFromString(c.goType)) && !isString(c.goType) { + c.err(schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) + return } - return formatMustHaveNoError(l.MustHaveNoErrors(), c.schema.Type, c.goType) } diff --git a/router/schema_validator/string_schema_validator_test.go b/router/schema_validator/string_schema_validator_test.go index 41e0646..b1b181b 100644 --- a/router/schema_validator/string_schema_validator_test.go +++ b/router/schema_validator/string_schema_validator_test.go @@ -7,10 +7,21 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/google/uuid" + "github.com/stretchr/testify/require" "github.com/piiano/cellotape/router/utils" ) +func TestStringSchemaValidatorWithByteFormat(t *testing.T) { + stringSchema := openapi3.NewStringSchema() + stringSchema.Format = "byte" + + validator := schemaValidator(*stringSchema) + errTemplate := "expect string schema to be compatible with %s type" + bytes := reflect.TypeOf([]byte{}) + expectTypeToBeCompatible(t, validator, bytes, errTemplate, bytes) +} + func TestStringSchemaValidatorPassForStringType(t *testing.T) { stringSchema := openapi3.NewStringSchema() validator := schemaValidator(*stringSchema) @@ -21,12 +32,15 @@ func TestStringSchemaValidatorPassForStringType(t *testing.T) { // according to the spec the string validation properties should apply only when the type is set to string func TestStringSchemaValidatorWithUntypedSchema(t *testing.T) { untypedSchemaWithUUIDFormat := openapi3.NewSchema().WithFormat(uuidFormat) - validator := schemaValidator(*untypedSchemaWithUUIDFormat) - for _, validType := range types { + + otherNonStringTypes := utils.Filter(types, func(t reflect.Type) bool { + return t != sliceOfBytesType && t != timeType + }) + + for _, validType := range otherNonStringTypes { t.Run(validType.String(), func(t *testing.T) { - if err := validator.WithType(validType).validateStringSchema(); err != nil { - t.Errorf("expect untyped schema to be compatible with %s type", validType) - } + err := schemaValidator(*untypedSchemaWithUUIDFormat).WithType(validType).Validate() + require.NoErrorf(t, err, "expect untyped schema to be compatible with %s type", validType) }) } } @@ -66,36 +80,39 @@ func TestUUIDFormatSchemaValidator(t *testing.T) { func TestTimeFormatSchemaValidator(t *testing.T) { timeSchema := openapi3.NewStringSchema().WithFormat(timeFormat) - validator := schemaValidator(*timeSchema) + errTemplate := "expect string schema with time format to be %s with %s type" - timeType := reflect.TypeOf(time.Now()) - expectTypeToBeCompatible(t, validator, timeType, errTemplate, "compatible", timeType) - expectTypeToBeCompatible(t, validator, stringType, errTemplate, "compatible", stringType) - // omit the uuid compatible types from all defined test types - var nonTimeCompatibleTypes = utils.Filter(types, func(t reflect.Type) bool { - return t != timeType && t != stringType && t != reflect.PointerTo(timeType) && t != reflect.PointerTo(stringType) - }) - for _, nonTimeCompatibleType := range nonTimeCompatibleTypes { - t.Run(nonTimeCompatibleType.String(), func(t *testing.T) { - expectTypeToBeIncompatible(t, validator, nonTimeCompatibleType, errTemplate, "incompatible", nonTimeCompatibleType) + timeCompatibleType := utils.NewSet(utils.Map([]reflect.Type{ + timeType, stringType, reflect.PointerTo(timeType), reflect.PointerTo(stringType), + }, reflect.Type.String)...) + + for _, goType := range append(types, timeType, reflect.PointerTo(timeType)) { + t.Run(goType.String(), func(t *testing.T) { + err := schemaValidator(*timeSchema).WithType(goType).Validate() + valid := timeCompatibleType.Has(goType.String()) + if valid { + require.NoErrorf(t, err, errTemplate, "compatible", goType) + } else { + require.Errorf(t, err, errTemplate, "incompatible", goType) + } }) } } func TestDateTimeFormatSchemaValidator(t *testing.T) { timeSchema := openapi3.NewStringSchema().WithFormat(dateTimeFormat) - validator := schemaValidator(*timeSchema) + errTemplate := "expect string schema with time format to be %s with %s type" timeType := reflect.TypeOf(time.Now()) - expectTypeToBeCompatible(t, validator, timeType, errTemplate, "compatible", timeType) - expectTypeToBeCompatible(t, validator, stringType, errTemplate, "compatible", stringType) + expectTypeToBeCompatible(t, schemaValidator(*timeSchema), timeType, errTemplate, "compatible", timeType) + expectTypeToBeCompatible(t, schemaValidator(*timeSchema), stringType, errTemplate, "compatible", stringType) // omit the uuid compatible types from all defined test types var nonTimeCompatibleTypes = utils.Filter(types, func(t reflect.Type) bool { return t != timeType && t != stringType && t != reflect.PointerTo(timeType) && t != reflect.PointerTo(stringType) }) for _, nonTimeCompatibleType := range nonTimeCompatibleTypes { t.Run(nonTimeCompatibleType.String(), func(t *testing.T) { - expectTypeToBeIncompatible(t, validator, nonTimeCompatibleType, errTemplate, "incompatible", nonTimeCompatibleType) + expectTypeToBeIncompatible(t, schemaValidator(*timeSchema), nonTimeCompatibleType, errTemplate, "incompatible", nonTimeCompatibleType) }) } } diff --git a/router/utils/logger.go b/router/utils/logger.go index ce008f7..9ce1757 100644 --- a/router/utils/logger.go +++ b/router/utils/logger.go @@ -142,7 +142,7 @@ func (l *logger) Log(level LogLevel, arg any) { write := func(string) {} if l.level != Off && l.level >= level { - write = func(levelStr string) { fmt.Fprintln(l.output, levelStr, arg) } + write = func(levelStr string) { fmt.Fprintln(l.output, levelStr, strings.Trim(fmt.Sprint(arg), "\n")) } } switch level { case Info: diff --git a/router/utils/multitype.go b/router/utils/multitype.go new file mode 100644 index 0000000..f7c0211 --- /dev/null +++ b/router/utils/multitype.go @@ -0,0 +1,140 @@ +package utils + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" +) + +var ErrInvalidUseOfMultiType = errors.New("invalid use of MultiType") + +var multiTypeReflectType = GetType[multiType]() + +func IsMultiType(t reflect.Type) bool { + return t.Implements(multiTypeReflectType) +} + +func ExtractMultiTypeTypes(mtType reflect.Type) ([]reflect.Type, error) { + multiTypeTypes := reflect.New(mtType).Elem().MethodByName("MultiTypeTypes") + + returnValues := multiTypeTypes.Call([]reflect.Value{}) + + if err := returnValues[1].Interface(); err != nil { + return nil, err.(error) + } + + return returnValues[0].Interface().([]reflect.Type), nil +} + +type multiType interface { + MultiTypeTypes() ([]reflect.Type, error) + json.Marshaler + json.Unmarshaler +} + +type MultiType[T any] struct { + Values T +} + +func (o *MultiType[T]) MultiTypeTypes() ([]reflect.Type, error) { + fields, err := o.fields() + if err != nil { + return nil, err + } + + return Map(fields, func(t reflect.StructField) reflect.Type { + return t.Type + }), nil +} + +func (o *MultiType[T]) fields() ([]reflect.StructField, error) { + structType := GetType[T]() + if structType.Kind() != reflect.Struct { + return []reflect.StructField{}, + fmt.Errorf("%w. expecting generic argument to be a struct", + ErrInvalidUseOfMultiType) + } + fieldsMap := StructKeys(structType, "") + for _, field := range fieldsMap { + if field.Type.Kind() != reflect.Pointer { + return []reflect.StructField{}, + fmt.Errorf("%w. field %q should be a pointer", + ErrInvalidUseOfMultiType, field.Name) + } + } + + fields := Map(Entries(fieldsMap), func(e Entry[string, reflect.StructField]) reflect.StructField { + return e.Value + }) + + if len(fields) == 0 { + return nil, fmt.Errorf("%w. must have at least one field", ErrInvalidUseOfMultiType) + } + + uniqueTypes := NewSet(Map(fields, func(t reflect.StructField) string { + return t.Type.String() + })...) + + if len(uniqueTypes) != len(fields) { + return nil, fmt.Errorf("%w. each field of MultiType must be of a different type", ErrInvalidUseOfMultiType) + } + + return fields, nil +} + +func (o *MultiType[T]) MarshalJSON() ([]byte, error) { + fields, err := o.fields() + if err != nil { + return nil, fmt.Errorf("can't marshal value to JSON due to %w", ErrInvalidUseOfMultiType) + } + + value := reflect.ValueOf(o.Values) + + fields = Filter(fields, func(field reflect.StructField) bool { + return !value.FieldByIndex(field.Index).IsNil() + }) + + if len(fields) == 0 { + return nil, &json.UnsupportedValueError{ + Value: reflect.ValueOf(o), + Str: "non of MultiType fields is set", + } + } + + if len(fields) > 1 { + return nil, &json.UnsupportedValueError{ + Value: reflect.ValueOf(o), + Str: "more than one field of MultiType is set", + } + } + + fieldValue := value.FieldByIndex(fields[0].Index) + + return json.Marshal(fieldValue.Interface()) +} + +func (o *MultiType[T]) UnmarshalJSON(bytes []byte) error { + fields, err := o.fields() + if err != nil { + return fmt.Errorf("can't unmarshal JSON to value due to %w", ErrInvalidUseOfMultiType) + } + structValue := reflect.ValueOf(&o.Values).Elem() + + for _, field := range fields { + fieldValue := structValue.FieldByIndex(field.Index) + value := fieldValue.Addr().Interface() + if err = json.Unmarshal(bytes, value); err == nil { + return nil + } + fieldValue.Set(reflect.Zero(fieldValue.Type())) + } + + unmarshalTypeError := &json.UnmarshalTypeError{} + if ok := errors.As(err, &unmarshalTypeError); ok { + unmarshalTypeError.Type = reflect.TypeOf(o) + return unmarshalTypeError + } + + return err +} diff --git a/router/utils/multitype_test.go b/router/utils/multitype_test.go new file mode 100644 index 0000000..688a3a9 --- /dev/null +++ b/router/utils/multitype_test.go @@ -0,0 +1,239 @@ +package utils + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +type Answer struct { + OneAnswer *string + ManyAnswers *[]string + UltimateAnswer *int +} + +func TestBadMultiType(t *testing.T) { + testCases := []struct { + name string + mt multiType + }{ + { + name: "non struct type", + mt: &MultiType[string]{}, + }, + { + name: "non pointer field", + mt: &MultiType[struct { + A string + B *int + }]{}, + }, + { + name: "empty struct", + mt: &MultiType[struct{}]{}, + }, + { + name: "non unique field type", + mt: &MultiType[struct { + A *string + B *string + }]{}, + }, + { + name: "non unique field type with anonymous structs", + mt: &MultiType[struct { + A *struct { + A string + } + B *struct { + A string + } + }]{}, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + mtType := reflect.TypeOf(testCase.mt) + require.True(t, IsMultiType(mtType)) + + _, err := ExtractMultiTypeTypes(mtType) + require.ErrorIs(t, err, ErrInvalidUseOfMultiType) + + err = json.Unmarshal([]byte(`""`), &testCase.mt) + require.ErrorIs(t, err, ErrInvalidUseOfMultiType) + + _, err = json.Marshal(testCase.mt) + require.ErrorIs(t, err, ErrInvalidUseOfMultiType) + }) + } +} + +func TestMultiTypeTypes(t *testing.T) { + testCases := []struct { + name string + mt multiType + types []reflect.Type + }{ + { + name: "answer", + mt: &MultiType[Answer]{}, + types: []reflect.Type{ + GetType[*string](), + GetType[*[]string](), + GetType[*int](), + }, + }, + { + name: "anonymous struct", + mt: &MultiType[struct { + A *bool + B *string + C *struct { + C1 []string + } + }]{}, + types: []reflect.Type{ + GetType[*bool](), + GetType[*string](), + GetType[*struct{ C1 []string }](), + }, + }, + { + name: "array of single of same struct", + mt: &MultiType[struct { + A *struct { + A string + } + B *[]struct { + A string + } + }]{}, + types: []reflect.Type{GetType[*struct { + A string + }](), GetType[*[]struct { + A string + }]()}, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + mtType := reflect.TypeOf(testCase.mt) + require.True(t, IsMultiType(mtType)) + types, err := ExtractMultiTypeTypes(mtType) + require.NoError(t, err) + require.ElementsMatch(t, testCase.types, types) + }) + } +} + +func TestMarshalMultiType(t *testing.T) { + testCases := []struct { + name string + input MultiType[Answer] + output string + }{ + { + name: "OneAnswer", + input: MultiType[Answer]{ + Values: Answer{ + OneAnswer: Ptr("foo"), + }, + }, + output: `"foo"`, + }, + { + name: "ManyAnswers", + input: MultiType[Answer]{ + Values: Answer{ + ManyAnswers: &[]string{"foo", "bar"}, + }, + }, + output: `["foo", "bar"]`, + }, + { + name: "UltimateAnswer", + input: MultiType[Answer]{ + Values: Answer{ + UltimateAnswer: Ptr(42), + }, + }, + output: `42`, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + bytes, err := json.Marshal(&test.input) + require.NoError(t, err) + require.JSONEq(t, test.output, string(bytes)) + }) + } +} + +func TestMarshalMultiTypeError(t *testing.T) { + multi := &MultiType[Answer]{} + + _, err := json.Marshal(multi) + require.Error(t, err) + require.IsType(t, &json.MarshalerError{}, err) + + multi = &MultiType[Answer]{ + Values: Answer{ + UltimateAnswer: Ptr(42), + OneAnswer: Ptr("foo"), + }, + } + + _, err = json.Marshal(multi) + require.Error(t, err) + require.IsType(t, &json.MarshalerError{}, err) +} + +func TestUnmarshalMultiType(t *testing.T) { + multi := &MultiType[Answer]{} + + err := json.Unmarshal([]byte(`["foo", "bar"]`), multi) + require.NoError(t, err) + require.NotNil(t, multi.Values.ManyAnswers) + require.ElementsMatch(t, []string{"foo", "bar"}, *multi.Values.ManyAnswers) + require.Nil(t, multi.Values.OneAnswer) + require.Nil(t, multi.Values.UltimateAnswer) + + multi = &MultiType[Answer]{} + + err = json.Unmarshal([]byte(`"foo"`), multi) + require.NoError(t, err) + require.NotNil(t, multi.Values.OneAnswer) + require.Equal(t, "foo", *multi.Values.OneAnswer) + require.Nil(t, multi.Values.ManyAnswers) + require.Nil(t, multi.Values.UltimateAnswer) + + multi = &MultiType[Answer]{} + + err = json.Unmarshal([]byte(`42`), multi) + require.NoError(t, err) + require.NotNil(t, multi.Values.UltimateAnswer) + require.Equal(t, 42, *multi.Values.UltimateAnswer) + require.Nil(t, multi.Values.ManyAnswers) + require.Nil(t, multi.Values.OneAnswer) + + multi = &MultiType[Answer]{} + + err = json.Unmarshal([]byte(`true`), multi) + require.Error(t, err) + require.IsType(t, &json.UnmarshalTypeError{}, err) + require.Equal(t, &json.UnmarshalTypeError{ + Value: "bool", + Type: reflect.TypeOf(multi), + Offset: 4, + }, err) + + multi = &MultiType[Answer]{} + + err = json.Unmarshal([]byte(`}`), multi) + require.Error(t, err) + require.IsType(t, &json.SyntaxError{}, err) +} diff --git a/router/utils/slices.go b/router/utils/slices.go index 61bc21e..8179d3c 100644 --- a/router/utils/slices.go +++ b/router/utils/slices.go @@ -53,3 +53,7 @@ func ConcatSlices[T any](slices ...[]T) []T { } return target } + +func Ptr[T any](value T) *T { + return &value +} diff --git a/router/utils/types.go b/router/utils/types.go new file mode 100644 index 0000000..ac5da20 --- /dev/null +++ b/router/utils/types.go @@ -0,0 +1,46 @@ +package utils + +import ( + "reflect" + "strings" +) + +// GetType returns reflect.Type of the generic parameter it receives. +func GetType[T any]() reflect.Type { return reflect.TypeOf(new(T)).Elem() } + +// Nil represents an empty type. +// You can use it with the HandlerFunc generic parameters to declare no Request with no request body, no path or query +// params, or responses with no response body. +type Nil *uintptr + +// NilType represent the type of Nil. +var NilType = GetType[Nil]() + +const ignoreFieldTagValue = "-" + +// StructKeys returns a map of "key" -> "field" for all the fields in the struct. +// Key is the field tag if exists or the field name otherwise. +// Field is the reflect.StructField of the field. +// +// Unexported fields or fields with tag value of "-" are ignored. +// +// StructKeys will recursively traverse all the embedded structs and return their fields as well. +func StructKeys(structType reflect.Type, tag string) map[string]reflect.StructField { + if structType == nil || structType == NilType { + return map[string]reflect.StructField{} + } + return FromEntries(ConcatSlices(Map(Filter(reflect.VisibleFields(structType), func(field reflect.StructField) bool { + return !field.Anonymous && field.IsExported() && field.Tag.Get(tag) != ignoreFieldTagValue + }), func(field reflect.StructField) Entry[string, reflect.StructField] { + name := field.Tag.Get(tag) + name, _, _ = strings.Cut(name, ",") + if name == "" { + name = field.Name + } + return Entry[string, reflect.StructField]{Key: name, Value: field} + }), ConcatSlices(Map(Filter(reflect.VisibleFields(structType), func(field reflect.StructField) bool { + return field.Anonymous && field.IsExported() && field.Tag.Get(tag) != ignoreFieldTagValue + }), func(field reflect.StructField) []Entry[string, reflect.StructField] { + return Entries(StructKeys(field.Type, tag)) + })...))) +} diff --git a/router/validations.go b/router/validations.go index bb9947d..ebeaf5b 100644 --- a/router/validations.go +++ b/router/validations.go @@ -3,7 +3,6 @@ package router import ( "fmt" "reflect" - "strings" "github.com/getkin/kin-openapi/openapi3" @@ -12,11 +11,10 @@ import ( ) const ( - pathParamInValue = "path" - pathParamFieldTag = "uri" - queryParamInValue = "query" - queryParamFieldTag = "form" - ignoreFieldTagValue = "-" + pathParamInValue = "path" + pathParamFieldTag = "uri" + queryParamInValue = "query" + queryParamFieldTag = "form" ) // validateOpenAPIRouter validates the entire OpenAPI Router structure built with the builder with the spec. @@ -95,7 +93,7 @@ func validateOperation(oa openapi, operation operation) error { func validateHandleAllPathParams(oa openapi, behaviour Behaviour, operation operation, specOp SpecOperation) utils.LogCounters { handlers := append(operation.handlers, operation.handler) declaredParams := utils.NewSet[string](utils.ConcatSlices[string](utils.Map(handlers, func(h handler) []string { - return utils.Keys(structKeys(h.request.pathParams, pathParamFieldTag)) + return utils.Keys(utils.StructKeys(h.request.pathParams, pathParamFieldTag)) })...)...) return validateHandleAllParams(oa, behaviour, operation, specOp, pathParamInValue, declaredParams) } @@ -104,7 +102,7 @@ func validateHandleAllPathParams(oa openapi, behaviour Behaviour, operation oper func validateHandleAllQueryParams(oa openapi, behaviour Behaviour, operation operation, specOp SpecOperation) utils.LogCounters { handlers := append(operation.handlers, operation.handler) declaredParams := utils.NewSet[string](utils.ConcatSlices[string](utils.Map(handlers, func(h handler) []string { - return utils.Keys(structKeys(h.request.queryParams, queryParamFieldTag)) + return utils.Keys(utils.StructKeys(h.request.queryParams, queryParamFieldTag)) })...)...) return validateHandleAllParams(oa, behaviour, operation, specOp, queryParamInValue, declaredParams) } @@ -152,7 +150,7 @@ func validateRequestBodyType(oa openapi, behaviour Behaviour, handler handler, s l := oa.logger() level := utils.LogLevel(behaviour) bodyType := handler.request.requestBody - if bodyType == nilType { + if bodyType == utils.NilType { return utils.LogCounters{} } if specBody == nil { @@ -190,11 +188,13 @@ func validateQueryParamsType(oa openapi, behaviour Behaviour, handler handler, s func validateParamsType(oa openapi, behaviour Behaviour, in string, tag string, paramsType reflect.Type, specParameters openapi3.Parameters, operationId string) utils.LogCounters { l := oa.logger() level := utils.LogLevel(behaviour) - if paramsType == nilType { + if paramsType == utils.NilType { return utils.LogCounters{} } - validator := schema_validator.NewTypeSchemaValidator(l, level, nilType, openapi3.Schema{}) - for name, field := range structKeys(paramsType, tag) { + + validator := schema_validator.NewTypeSchemaValidator(utils.NilType, openapi3.Schema{}) + + for name, field := range utils.StructKeys(paramsType, tag) { specParameter := specParameters.GetByInAndName(in, name) if specParameter == nil { l.Logf(level, paramDefinedByHandlerButMissingInSpec(in, name, paramsType, operationId)) @@ -203,31 +203,14 @@ func validateParamsType(oa openapi, behaviour Behaviour, in string, tag string, // TODO: schema validator check object schemas with json keys if err := validator.WithType(field.Type).WithSchema(*specParameter.Schema.Value).Validate(); err != nil { l.Logf(level, incompatibleParamType(operationId, in, name, field.Name, field.Type)) + for _, errMessage := range validator.Errors() { + l.Log(level, errMessage) + } } } return l.Counters() } -func structKeys(structType reflect.Type, tag string) map[string]reflect.StructField { - if structType == nil || structType == nilType { - return map[string]reflect.StructField{} - } - return utils.FromEntries(utils.ConcatSlices(utils.Map(utils.Filter(reflect.VisibleFields(structType), func(field reflect.StructField) bool { - return !field.Anonymous && field.IsExported() && field.Tag.Get(tag) != ignoreFieldTagValue - }), func(field reflect.StructField) utils.Entry[string, reflect.StructField] { - name := field.Tag.Get(tag) - name, _, _ = strings.Cut(name, ",") - if name == "" { - name = field.Name - } - return utils.Entry[string, reflect.StructField]{Key: name, Value: field} - }), utils.ConcatSlices(utils.Map(utils.Filter(reflect.VisibleFields(structType), func(field reflect.StructField) bool { - return field.Anonymous && field.IsExported() && field.Tag.Get(tag) != ignoreFieldTagValue - }), func(field reflect.StructField) []utils.Entry[string, reflect.StructField] { - return utils.Entries(structKeys(field.Type, tag)) - })...))) -} - // validateResponseTypes check that all responses declared on a handler are available on the spec with a compatible schema. // a handler does not have to declare and handle all possible responses defined in the spec, but it can not declare responses which are not defined. func validateResponseTypes(oa openapi, behaviour Behaviour, handler handler, specOperation *openapi3.Operation, operationId string) utils.LogCounters { diff --git a/router/validations_test.go b/router/validations_test.go index 104b7d7..01b406c 100644 --- a/router/validations_test.go +++ b/router/validations_test.go @@ -1,6 +1,7 @@ package router import ( + "bytes" "reflect" "testing" @@ -13,7 +14,7 @@ import ( func TestValidateContentTypes(t *testing.T) { err := validateContentTypes(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, utils.NewSet[string]()) require.NoError(t, err) @@ -21,7 +22,7 @@ func TestValidateContentTypes(t *testing.T) { func TestValidateContentTypesWithJSONContentType(t *testing.T) { err := validateContentTypes(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), spec: OpenAPISpec(openapi3.T{ Paths: openapi3.Paths{ @@ -40,7 +41,7 @@ func TestValidateContentTypesWithJSONContentType(t *testing.T) { func TestValidateContentTypesWithExcludedOperation(t *testing.T) { err := validateContentTypes(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), spec: OpenAPISpec(openapi3.T{ Paths: openapi3.Paths{ @@ -60,7 +61,7 @@ func TestValidateContentTypesWithExcludedOperation(t *testing.T) { func TestValidateContentTypesErrorWithMissingJSONContentType(t *testing.T) { err := validateContentTypes(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: ContentTypes{}, spec: OpenAPISpec(openapi3.T{ Paths: openapi3.Paths{ @@ -79,7 +80,7 @@ func TestValidateContentTypesErrorWithMissingJSONContentType(t *testing.T) { func TestValidateHandleAllPathParams(t *testing.T) { counter := validateHandleAllPathParams(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), }, PropagateError, operation{ handler: handler{ request: requestTypes{ @@ -106,7 +107,7 @@ func TestValidateHandleAllPathParams(t *testing.T) { func TestValidateHandleAllQueryParams(t *testing.T) { counter := validateHandleAllQueryParams(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), }, PropagateError, operation{ handler: handler{ request: requestTypes{ @@ -133,7 +134,7 @@ func TestValidateHandleAllQueryParams(t *testing.T) { func TestValidateHandleAllResponses(t *testing.T) { counter := validateHandleAllResponses(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, operation{ handler: handler{ @@ -156,7 +157,7 @@ func TestValidateHandleAllResponses(t *testing.T) { func TestValidateHandleAllResponsesError(t *testing.T) { counter := validateHandleAllResponses(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, operation{ handler: handler{ @@ -179,7 +180,7 @@ func TestValidateHandleAllResponsesError(t *testing.T) { func TestValidateHandleAllResponsesInvalidStatusError(t *testing.T) { counter := validateHandleAllResponses(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, operation{ handler: handler{ @@ -202,7 +203,7 @@ func TestValidateHandleAllResponsesInvalidStatusError(t *testing.T) { func TestValidateHandleAllResponsesMissingStatusError(t *testing.T) { counter := validateHandleAllResponses(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, operation{ handler: handler{ @@ -219,7 +220,7 @@ func TestValidateHandleAllResponsesMissingStatusError(t *testing.T) { func TestValidateRequestBodyType(t *testing.T) { counter := validateRequestBodyType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, handler{ request: requestTypes{ @@ -242,7 +243,7 @@ func TestValidateRequestBodyType(t *testing.T) { func TestValidateRequestBodyTypeIgnoreMissingContentType(t *testing.T) { counter := validateRequestBodyType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), }, PropagateError, handler{ request: requestTypes{ requestBody: reflect.TypeOf(""), @@ -264,7 +265,7 @@ func TestValidateRequestBodyTypeIgnoreMissingContentType(t *testing.T) { func TestValidateRequestBodyTypeErrorWithNoBodyInSpec(t *testing.T) { counter := validateRequestBodyType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, handler{ request: requestTypes{ @@ -287,7 +288,7 @@ func TestValidateRequestBodyTypeErrorWithNoBodyInSpec(t *testing.T) { func TestValidateRequestBodyTypeErrorWithIcompatibleSchema(t *testing.T) { counter := validateRequestBodyType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, handler{ request: requestTypes{ @@ -303,7 +304,7 @@ func TestValidateQueryParamsType(t *testing.T) { assert.Equal(t, 0, counter.Errors) assert.Equal(t, 0, counter.Warnings) counter = validateQueryParamsType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), }, PropagateError, handler{ request: requestTypes{ queryParams: reflect.TypeOf(struct { @@ -324,7 +325,7 @@ func TestValidateCollidingEmbeddedQueryQueryParamsType(t *testing.T) { assert.Equal(t, 0, counter.Errors) assert.Equal(t, 0, counter.Warnings) counter = validateHandleAllQueryParams(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), }, PropagateError, operation{ handler: handler{ request: requestTypes{ @@ -352,7 +353,7 @@ func TestValidateQueryParamsTypeFailWhenMissingInSpec(t *testing.T) { assert.Equal(t, 0, counter.Errors) assert.Equal(t, 0, counter.Warnings) counter = validateQueryParamsType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), }, PropagateError, handler{ request: requestTypes{ queryParams: reflect.TypeOf(struct { @@ -369,7 +370,7 @@ func TestValidateQueryParamsTypeFailWhenIncompatibleType(t *testing.T) { assert.Equal(t, 0, counter.Errors) assert.Equal(t, 0, counter.Warnings) counter = validateQueryParamsType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), }, PropagateError, handler{ request: requestTypes{ queryParams: reflect.TypeOf(struct { @@ -381,7 +382,7 @@ func TestValidateQueryParamsTypeFailWhenIncompatibleType(t *testing.T) { Value: openapi3.NewQueryParameter("foo").WithSchema(openapi3.NewIntegerSchema()), }, }, "") - assert.Equal(t, 1, counter.Errors) + assert.Equal(t, 4, counter.Errors) assert.Equal(t, 0, counter.Warnings) } @@ -390,7 +391,7 @@ func TestValidatePathParamsType(t *testing.T) { assert.Equal(t, 0, counter.Errors) assert.Equal(t, 0, counter.Warnings) counter = validatePathParamsType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), }, PropagateError, handler{ request: requestTypes{ pathParams: reflect.TypeOf(struct { @@ -411,7 +412,7 @@ func TestValidatePathParamsTypeFailWhenMissingInSpec(t *testing.T) { assert.Equal(t, 0, counter.Errors) assert.Equal(t, 0, counter.Warnings) counter = validatePathParamsType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), }, PropagateError, handler{ request: requestTypes{ pathParams: reflect.TypeOf(struct { @@ -427,8 +428,9 @@ func TestValidatePathParamsTypeFailWhenIncompatibleType(t *testing.T) { counter := validatePathParamsType(openapi{}, PropagateError, handler{}, openapi3.Parameters{}, "") assert.Equal(t, 0, counter.Errors) assert.Equal(t, 0, counter.Warnings) + counter = validatePathParamsType(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, handler{ request: requestTypes{ @@ -441,34 +443,54 @@ func TestValidatePathParamsTypeFailWhenIncompatibleType(t *testing.T) { Value: openapi3.NewPathParameter("foo").WithSchema(openapi3.NewIntegerSchema()), }, }, "") - assert.Equal(t, 1, counter.Errors) + assert.Equal(t, 4, counter.Errors) assert.Equal(t, 0, counter.Warnings) } func TestStructKeys(t *testing.T) { - structType := reflect.TypeOf(struct { + structType := utils.GetType[struct { Field1 string `json:"field1"` Field2 int `json:",omitempty"` Field3 bool - }{}) - keys := structKeys(structType, "json") + }]() + keys := utils.StructKeys(structType, "json") assert.Equal(t, map[string]reflect.StructField{ "field1": structType.Field(0), "Field2": structType.Field(1), "Field3": structType.Field(2), }, keys) - structType2 := reflect.TypeOf(struct { + structType2 := utils.GetType[struct { Field1 string `form:"field1"` Field2 int `form:",omitempty"` Field3 bool - }{}) - keys2 := structKeys(structType2, "form") + }]() + keys2 := utils.StructKeys(structType2, "form") assert.Equal(t, map[string]reflect.StructField{ "field1": structType2.Field(0), "Field2": structType2.Field(1), "Field3": structType2.Field(2), }, keys2) + + type Embedded2 struct { + Field3 string `tagName:"field3"` + } + type Embedded struct { + Embedded2 + IgnoredField string `tagName:"-"` + } + structType3 := utils.GetType[struct { + Field1 string `tagName:"field1"` + Field2 int `tagName:",omitempty"` + Embedded + }]() + keys3 := utils.StructKeys(structType3, "tagName") + assert.Equal(t, map[string]reflect.StructField{ + "field1": structType3.Field(0), + "Field2": structType3.Field(1), + "field3": structType3.FieldByIndex([]int{2, 0, 0}), + }, keys3) + assert.Equal(t, keys3["field3"].Name, "Field3") } func TestValidateResponseTypes(t *testing.T) { @@ -477,6 +499,7 @@ func TestValidateResponseTypes(t *testing.T) { assert.Equal(t, 0, counter.Warnings) counter = validateResponseTypes(openapi{ + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, handler{ responses: handlerResponses{ @@ -513,7 +536,7 @@ func TestValidateResponseTypesIgnoreMissingContentType(t *testing.T) { func TestValidateResponseTypesMissingStatusErr(t *testing.T) { counter := validateResponseTypes(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, handler{ responses: handlerResponses{ @@ -531,7 +554,7 @@ func TestValidateResponseTypesMissingStatusErr(t *testing.T) { func TestValidateResponseTypesIncompatibleTypeErr(t *testing.T) { counter := validateResponseTypes(openapi{ - options: DefaultOptions(), + options: DefaultTestOptions(), contentTypes: DefaultContentTypes(), }, PropagateError, handler{ responses: handlerResponses{ @@ -547,6 +570,100 @@ func TestValidateResponseTypesIncompatibleTypeErr(t *testing.T) { assert.Equal(t, 0, counter.Warnings) } +func TestImplementingExcludedOperationErr(t *testing.T) { + spec := NewSpec() + testOperation := openapi3.NewOperation() + testOperation.OperationID = "test" + spec.Paths = openapi3.Paths{ + "/test": &openapi3.PathItem{ + Get: testOperation, + }, + } + + options := DefaultTestOptions() + options.ExcludeOperations = []string{"test"} + + err := validateOpenAPIRouter(&openapi{ + spec: spec, + options: options, + }, []operation{ + { + id: "test", + handler: handler{ + request: requestTypes{ + requestBody: utils.NilType, + pathParams: utils.NilType, + queryParams: utils.NilType, + }, + }, + }, + }) + require.Error(t, err) +} + +func TestImplementingSameOperationMultipleTimesErr(t *testing.T) { + spec := NewSpec() + testOperation := openapi3.NewOperation() + testOperation.OperationID = "test" + spec.Paths = openapi3.Paths{ + "/test": &openapi3.PathItem{ + Get: testOperation, + }, + } + + opImpl := operation{ + id: "test", + handler: handler{ + request: requestTypes{ + requestBody: utils.NilType, + pathParams: utils.NilType, + queryParams: utils.NilType, + }, + }, + } + err := validateOpenAPIRouter(&openapi{ + spec: spec, + options: DefaultTestOptions(), + }, []operation{opImpl, opImpl}) + require.Error(t, err) +} + +func TestMissingOperationImplementationErr(t *testing.T) { + spec := NewSpec() + testOperation := openapi3.NewOperation() + testOperation.OperationID = "test" + spec.Paths = openapi3.Paths{ + "/test": &openapi3.PathItem{ + Get: testOperation, + }, + } + + err := validateOpenAPIRouter(&openapi{ + spec: spec, + options: DefaultTestOptions(), + }, []operation{}) + require.Error(t, err) +} + +func TestMissingOperationInSpecErr(t *testing.T) { + err := validateOpenAPIRouter(&openapi{ + spec: NewSpec(), + options: DefaultTestOptions(), + }, []operation{ + { + id: "test", + handler: handler{ + request: requestTypes{ + requestBody: utils.NilType, + pathParams: utils.NilType, + queryParams: utils.NilType, + }, + }, + }, + }) + require.Error(t, err) +} + func testSpecResponse(status string, contentType string, schema *openapi3.Schema) map[string]*openapi3.ResponseRef { return map[string]*openapi3.ResponseRef{ status: { @@ -562,3 +679,9 @@ func testSpecResponse(status string, contentType string, schema *openapi3.Schema }, } } + +func DefaultTestOptions() Options { + options := DefaultOptions() + options.LogOutput = bytes.NewBuffer([]byte{}) + return options +}