From e9724a409c7db274bc65a796f66441a16fbb8365 Mon Sep 17 00:00:00 2001 From: Ori Shalom Date: Sun, 21 May 2023 03:06:45 +0300 Subject: [PATCH] fix bugs and improve test infra --- go.mod | 2 +- router/binders.go | 4 - router/binders_test.go | 210 ++++-------------- router/content_types_test.go | 107 +++++++++ router/core.go | 75 ++++--- router/core_test.go | 1 - router/helpers_test.go | 2 +- router/runtime_errors_test.go | 58 ++++- .../array_schema_validator.go | 2 +- router/schema_validator/assetions.go | 16 +- .../object_schema_validator.go | 7 +- .../schema_format_validator.go | 3 + .../string_schema_validator.go | 12 +- .../string_schema_validator_test.go | 16 +- 14 files changed, 298 insertions(+), 217 deletions(-) diff --git a/go.mod b/go.mod index 134094b..0dae9db 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/piiano/cellotape -go 1.18 +go 1.20 retract v1.0.0 // Published accidentally. diff --git a/router/binders.go b/router/binders.go index 1d71b4f..a331d48 100644 --- a/router/binders.go +++ b/router/binders.go @@ -270,10 +270,6 @@ func requestValidationInput(ctx *Context) *openapi3filter.RequestValidationInput ParamDecoder: nil, } - input.Options.WithCustomSchemaErrorFunc(func(err *openapi3.SchemaError) string { - return err.Reason - }) - if ctx.Request != nil { input.Request = ctx.Request if ctx.Request.URL != nil { diff --git a/router/binders_test.go b/router/binders_test.go index 3219a62..f78866f 100644 --- a/router/binders_test.go +++ b/router/binders_test.go @@ -1,12 +1,10 @@ package router import ( - "bytes" "errors" "io" "net/http" "net/http/httptest" - "net/url" "reflect" "testing" @@ -37,16 +35,7 @@ type StructType struct { func TestQueryBinderFactory(t *testing.T) { queryBinder := queryBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - requestURL, err := url.Parse("http:0.0.0.0:90/abc?Foo=42") - require.NoError(t, err) - err = queryBinder(&Context{ - Request: &http.Request{ - URL: requestURL, - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶ms) + err := queryBinder(testContext(withURL(t, "http:0.0.0.0:90/abc?Foo=42")), ¶ms) require.NoError(t, err) assert.Equal(t, StructType{Foo: 42}, params) } @@ -58,16 +47,7 @@ type StructWithArrayType struct { func TestQueryBinderFactoryWithArrayType(t *testing.T) { queryBinder := queryBinderFactory[StructWithArrayType](reflect.TypeOf(StructWithArrayType{})) var params StructWithArrayType - requestURL, err := url.Parse("http:0.0.0.0:90/abc?Foo=42&Foo=6&Foo=7") - require.NoError(t, err) - err = queryBinder(&Context{ - Request: &http.Request{ - URL: requestURL, - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶ms) + err := queryBinder(testContext(withURL(t, "http:0.0.0.0:90/abc?Foo=42&Foo=6&Foo=7")), ¶ms) require.NoError(t, err) assert.Equal(t, StructWithArrayType{Foo: []int{42, 6, 7}}, params) } @@ -75,48 +55,24 @@ func TestQueryBinderFactoryWithArrayType(t *testing.T) { func TestQueryBinderFactoryMultipleParamToNonArrayError(t *testing.T) { queryBinder := queryBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - requestURL, err := url.Parse("http:0.0.0.0:90/abc?Foo=42&Foo=6&Foo=7") - require.NoError(t, err) - err = queryBinder(&Context{ - Request: &http.Request{ - URL: requestURL, - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶ms) + err := queryBinder(testContext(withURL(t, "http:0.0.0.0:90/abc?Foo=42&Foo=6&Foo=7")), ¶ms) require.Error(t, err) } func TestQueryBinderFactoryError(t *testing.T) { queryBinder := queryBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - requestURL, err := url.Parse("http:0.0.0.0:90/abc?Foo=abc") - require.NoError(t, err) - err = queryBinder(&Context{ - Request: &http.Request{ - URL: requestURL, - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶ms) - + err := queryBinder(testContext(withURL(t, "http:0.0.0.0:90/abc?Foo=abc")), ¶ms) require.Error(t, err) } func TestPathBinderFactory(t *testing.T) { pathBinder := pathBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - err := pathBinder(&Context{ - Params: &httprouter.Params{{ - Key: "Foo", - Value: "42", - }}, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶ms) + err := pathBinder(testContext(withParams(&httprouter.Params{{ + Key: "Foo", + Value: "42", + }})), ¶ms) require.NoError(t, err) assert.Equal(t, StructType{Foo: 42}, params) } @@ -124,49 +80,32 @@ func TestPathBinderFactory(t *testing.T) { func TestPathBinderFactoryError(t *testing.T) { pathBinder := pathBinderFactory[StructType](reflect.TypeOf(StructType{})) var params StructType - err := pathBinder(&Context{ - Params: &httprouter.Params{{ - Key: "Foo", - Value: "bar", - }}, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶ms) + err := pathBinder(testContext(withParams(&httprouter.Params{{ + Key: "Foo", + Value: "bar", + }})), ¶ms) require.Error(t, err) } func TestRequestBodyBinderFactory(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&Context{ - Request: &http.Request{ - Body: io.NopCloser(bytes.NewBuffer([]byte("42"))), - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶m) + err := requestBodyBinder(testContext(withBody("42")), ¶m) require.NoError(t, err) assert.Equal(t, 42, param) } func TestRequestBodyBinderFactoryWithSchema(t *testing.T) { - operation := openapi3.NewOperation() - operation.RequestBody = &openapi3.RequestBodyRef{ + testOp := openapi3.NewOperation() + testOp.RequestBody = &openapi3.RequestBodyRef{ Value: openapi3.NewRequestBody().WithJSONSchema(openapi3.NewIntegerSchema()), } requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&Context{ - Request: &http.Request{ - Header: map[string][]string{"Content-Type": {"application/json"}}, - Body: io.NopCloser(bytes.NewBuffer([]byte("42"))), - }, - Operation: SpecOperation{ - Operation: operation, - }, - }, ¶m) + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "application/json"), + withOperation(testOp)), ¶m) require.NoError(t, err) assert.Equal(t, 42, param) } @@ -174,14 +113,8 @@ func TestRequestBodyBinderFactoryWithSchema(t *testing.T) { func TestRequestBodyBinderFactoryError(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&Context{ - Request: &http.Request{ - Body: io.NopCloser(bytes.NewBuffer([]byte(`"foo"`))), - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶m) + + err := requestBodyBinder(testContext(withBody(`"foo"`)), ¶m) require.Error(t, err) } @@ -196,44 +129,27 @@ func (r readerWithError) Read(_ []byte) (int, error) { func TestRequestBodyBinderFactoryReaderError(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&Context{ - Request: &http.Request{ - Body: io.NopCloser(readerWithError(`42`)), - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶m) + err := requestBodyBinder(testContext( + withBodyReader(io.NopCloser(readerWithError(`42`)))), ¶m) require.Error(t, err) } func TestRequestBodyBinderFactoryContentTypeError(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&Context{ - Request: &http.Request{ - Header: http.Header{"Content-Type": {"no-such-content-type"}}, - Body: io.NopCloser(bytes.NewBuffer([]byte(`42`))), - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶m) + + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "no-such-content-type")), ¶m) require.Error(t, err) } func TestRequestBodyBinderFactoryContentTypeWithCharset(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&Context{ - Request: &http.Request{ - Header: http.Header{"Content-Type": {"application/json; charset=utf-8"}}, - Body: io.NopCloser(bytes.NewBuffer([]byte("42"))), - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶m) + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "application/json; charset=utf-8")), ¶m) require.NoError(t, err) assert.Equal(t, 42, param) } @@ -241,30 +157,18 @@ func TestRequestBodyBinderFactoryContentTypeWithCharset(t *testing.T) { func TestRequestBodyBinderFactoryInvalidContentType(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&Context{ - Request: &http.Request{ - Header: http.Header{"Content-Type": {"invalid content type"}}, - Body: io.NopCloser(bytes.NewBuffer([]byte("42"))), - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶m) + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "invalid content type")), ¶m) require.Error(t, err) } func TestRequestBodyBinderFactoryContentTypeAnyWithCharset(t *testing.T) { requestBodyBinder := requestBodyBinderFactory[int](reflect.TypeOf(0), DefaultContentTypes()) var param int - err := requestBodyBinder(&Context{ - Request: &http.Request{ - Header: http.Header{"Content-Type": {"*/*; charset=utf-8"}}, - Body: io.NopCloser(bytes.NewBuffer([]byte("42"))), - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶m) + err := requestBodyBinder(testContext( + withBody("42"), + withHeader("Content-Type", "*/*; charset=utf-8")), ¶m) require.NoError(t, err) assert.Equal(t, 42, param) } @@ -282,17 +186,11 @@ type CollidingFieldsParams struct { func TestBindingEmbeddedQueryParamsCollidingFields(t *testing.T) { requestBodyBinder := queryBinderFactory[CollidingFieldsParams](reflect.TypeOf(CollidingFieldsParams{})) - requestURL, err := url.Parse("http://http:0.0.0.0:8080/path?param1=foo¶m2=bar") - require.NoError(t, err) var param CollidingFieldsParams - err = requestBodyBinder(&Context{ - Request: &http.Request{ - URL: requestURL, - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶m) + + ctx := testContext(withURL(t, "http://http:0.0.0.0:8080/path?param1=foo¶m2=bar")) + + err := requestBodyBinder(ctx, ¶m) require.NoError(t, err) require.Equal(t, "foo", param.CollidingFieldsParam1.Value) require.Equal(t, "bar", param.CollidingFieldsParam2.Value) @@ -311,17 +209,10 @@ type CollidingParams struct { func TestBindingEmbeddedQueryParamsCollidingParams(t *testing.T) { requestBodyBinder := queryBinderFactory[CollidingParams](reflect.TypeOf(CollidingParams{})) - requestURL, err := url.Parse("http://http:0.0.0.0:8080/path?param1=42") - require.NoError(t, err) + var param CollidingParams - err = requestBodyBinder(&Context{ - Request: &http.Request{ - URL: requestURL, - }, - Operation: SpecOperation{ - Operation: openapi3.NewOperation(), - }, - }, ¶m) + err := requestBodyBinder(testContext( + withURL(t, "http://http:0.0.0.0:8080/path?param1=42")), ¶m) require.NoError(t, err) require.Equal(t, "42", param.CollidingParamString.Value) require.Equal(t, 42, param.CollidingParamInt.Value) @@ -362,16 +253,11 @@ func TestErrOnWriterError(t *testing.T) { for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - _, err := binder(&Context{ - Operation: SpecOperation{ - Operation: testOp, - }, - Request: &http.Request{ - URL: &url.URL{}, - }, - Writer: test.writer, - RawResponse: &RawResponse{}, - }, response) + ctx := testContext( + withOperation(testOp), + withResponseWriter(test.writer), + ) + _, err := binder(ctx, response) test.assertion(t, err) }) } diff --git a/router/content_types_test.go b/router/content_types_test.go index 65632f8..76b9f9e 100644 --- a/router/content_types_test.go +++ b/router/content_types_test.go @@ -1,8 +1,15 @@ package router import ( + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" "reflect" "testing" + "testing/iotest" "github.com/getkin/kin-openapi/openapi3" "github.com/stretchr/testify/assert" @@ -67,6 +74,106 @@ func TestContentTypeMime(t *testing.T) { } } +type foo struct { + Foo string `json:"foo"` +} +type fooContentType struct { + shouldErr bool +} + +func (f fooContentType) Mime() string { return "foo" } + +func (f fooContentType) Encode(a any) ([]byte, error) { + return []byte(a.(foo).Foo), nil +} + +func (f fooContentType) Decode(bytes []byte, a any) error { + if f.shouldErr { + return errors.New("foo decode error") + } + switch typedValue := a.(type) { + case *foo: + (*typedValue).Foo = string(bytes) + case *any: + *typedValue = string(bytes) + } + return nil +} + +func (f fooContentType) ValidateTypeSchema(_ utils.Logger, _ utils.LogLevel, _ reflect.Type, _ openapi3.Schema) error { + return nil +} + +func TestValidationsWithCustomContentType(t *testing.T) { + testSpec, err := NewSpecFromData([]byte(` +paths: + /test: + post: + operationId: test + requestBody: + content: + foo: + schema: + type: string + responses: + '200': + description: ok +`)) + require.NoError(t, err) + + testCases := []struct { + contentType ContentType + bodyReader io.ReadCloser + shouldErr bool + }{ + { + contentType: fooContentType{}, + bodyReader: io.NopCloser(bytes.NewBufferString("bar")), + }, + { + contentType: fooContentType{shouldErr: true}, + bodyReader: io.NopCloser(bytes.NewBufferString("bar")), + shouldErr: true, + }, + { + contentType: fooContentType{}, + bodyReader: io.NopCloser(iotest.ErrReader(errors.New("failed reading body"))), + shouldErr: true, + }, + } + + for _, test := range testCases { + var calledWithBody *foo + var badRequestErr error + router := NewOpenAPIRouter(testSpec). + WithContentType(test.contentType). + WithOperation("test", HandlerFunc[foo, Nil, Nil, OKResponse[Nil]](func(_ *Context, r Request[foo, Nil, Nil]) (Response[OKResponse[Nil]], error) { + calledWithBody = &r.Body + return SendOK(OKResponse[Nil]{}), nil + }), ErrorHandler(func(_ *Context, err error) (Response[any], error) { + badRequestErr = err + return Response[any]{}, nil + })) + handler, err := router.AsHandler() + require.NoError(t, err) + + handler.ServeHTTP(&httptest.ResponseRecorder{}, &http.Request{ + Method: http.MethodPost, + URL: &url.URL{Path: "/test"}, + Header: http.Header{"Content-Type": []string{"foo"}}, + Body: test.bodyReader, + }) + + if test.shouldErr { + //require.Nil(t, calledWithBody) + require.Error(t, badRequestErr) + } else { + assert.Equal(t, foo{Foo: "bar"}, *calledWithBody) + require.NoError(t, badRequestErr) + } + } +} + func TestOctetStreamContentTypeBytesSlice(t *testing.T) { encodedBytes, err := OctetStreamContentType{}.Encode([]byte("foo")) require.NoError(t, err) diff --git a/router/core.go b/router/core.go index d57a715..40b9e80 100644 --- a/router/core.go +++ b/router/core.go @@ -1,6 +1,7 @@ package router import ( + "encoding/json" "errors" "io" "log" @@ -16,16 +17,17 @@ import ( ) func createMainRouterHandler(oa *openapi) (http.Handler, error) { + // Customize the error message returned by the kin-openapi library to be more user-friendly. + openapi3filter.DefaultOptions.WithCustomSchemaErrorFunc(func(err *openapi3.SchemaError) string { + return err.Reason + }) flatOperations := flattenOperations(oa.group) if err := validateOpenAPIRouter(oa, flatOperations); err != nil { return nil, err } router := httprouter.New() router.HandleMethodNotAllowed = false - ////router.PanicHandler = nil - //router.PanicHandler = func(writer http.ResponseWriter, request *http.Request, i interface{}) { - // log.Println("http-router handler") - //} + logger := oa.logger() pathParamsMatcher := regexp.MustCompile(`\{([^/}]*)}`) @@ -45,42 +47,46 @@ func createMainRouterHandler(oa *openapi) (http.Handler, error) { openapi3filter.RegisterBodyEncoder(contentType.Mime(), contentType.Encode) } if openapi3filter.RegisteredBodyDecoder(mimeType) == nil { - openapi3filter.RegisterBodyDecoder(contentType.Mime(), func(reader io.Reader, _ http.Header, schema *openapi3.SchemaRef, enc openapi3filter.EncodingFn) (any, error) { - //err := contentType.ValidateTypeSchema(oa.logger(), - // oa.options.LogLevel, - // utils.GetType[any](), - // *schema.Value) - // - //bytes, err := io.ReadAll(reader) - //if err != nil { - // return nil, err - //} - // - //var target any - //if err = contentType.Decode(bytes, &target); err != nil { - // return nil, err - //} - switch schema.Value.Type { - case openapi3.TypeArray: - return []any{}, nil - case openapi3.TypeObject: - return map[string]any{}, nil - case openapi3.TypeBoolean: - return false, nil - case openapi3.TypeString: - return "", nil - case openapi3.TypeNumber, openapi3.TypeInteger: - return 0, nil - } - return nil, nil - }) + openapi3filter.RegisterBodyDecoder(contentType.Mime(), createDecoder(contentType)) } - } return router, nil } +func createDecoder(contentType ContentType) func(reader io.Reader, _ http.Header, schema *openapi3.SchemaRef, enc openapi3filter.EncodingFn) (any, error) { + return func(reader io.Reader, _ http.Header, schema *openapi3.SchemaRef, enc openapi3filter.EncodingFn) (any, error) { + bytes, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + + var target any + if err = contentType.Decode(bytes, &target); err != nil { + return nil, err + } + + // For kin-openapi to be able to validate a request it requires that the decoded value will on of + // the values received when decoding JSON to any. + // e.g. any, []any, []map[string]any, etc. + // + // After using the custom decoder we get a value of the type of the target struct. + // To overcome this we marshal the target to JSON and then unmarshal it to any. + + jsonBytes, err := json.Marshal(target) + if err != nil { + return nil, err + } + + var jsonValue any + if err = json.Unmarshal(jsonBytes, &jsonValue); err != nil { + return nil, err + } + + return jsonValue, nil + } +} + func (oa *openapi) logger() utils.Logger { return utils.NewLoggerWithLevel(oa.options.LogOutput, oa.options.LogLevel) } @@ -128,6 +134,7 @@ func asHttpRouterHandler(oa openapi, specOp SpecOperation, head BoundHandlerFunc Params: ¶ms, RawResponse: &RawResponse{Status: 0}, } + _, err := head(ctx) if err != nil || ctx.RawResponse.Status == 0 { writer.WriteHeader(500) diff --git a/router/core_test.go b/router/core_test.go index 66a3339..87564bd 100644 --- a/router/core_test.go +++ b/router/core_test.go @@ -32,7 +32,6 @@ func TestName(t *testing.T) { } func TestFailStartOnValidationError(t *testing.T) { - _, err := createMainRouterHandler(&openapi{ spec: NewSpec(), options: DefaultOptions(), diff --git a/router/helpers_test.go b/router/helpers_test.go index f9f9bf5..e69d7fe 100644 --- a/router/helpers_test.go +++ b/router/helpers_test.go @@ -127,7 +127,7 @@ func TestRawHandler(t *testing.T) { handlerFunc := rawHandler.handlerFactory(openapi{}, func(c *Context) (RawResponse, error) { return rawResponse, nil }) - resp, err := handlerFunc(&Context{Request: &http.Request{}, RawResponse: &RawResponse{}}) + resp, err := handlerFunc(testContext()) require.ErrorIs(t, err, UnsupportedResponseStatusErr) assert.Zero(t, resp) diff --git a/router/runtime_errors_test.go b/router/runtime_errors_test.go index 8dc0d80..3cecaca 100644 --- a/router/runtime_errors_test.go +++ b/router/runtime_errors_test.go @@ -1,7 +1,9 @@ package router import ( + "bytes" "errors" + "io" "net/http" "net/http/httptest" "net/url" @@ -16,8 +18,10 @@ import ( "github.com/piiano/cellotape/router/utils" ) -var testContext = func() *Context { - return &Context{ +type contextModifier func(*Context) + +func testContext(modifiers ...contextModifier) *Context { + ctx := &Context{ Operation: SpecOperation{ Operation: openapi3.NewOperation(), }, @@ -29,6 +33,56 @@ var testContext = func() *Context { Writer: &httptest.ResponseRecorder{}, Params: &httprouter.Params{}, } + for _, modifier := range modifiers { + modifier(ctx) + } + return ctx +} + +func withURL(t *testing.T, urlString string) contextModifier { + urlValue, err := url.Parse(urlString) + require.NoError(t, err) + return func(ctx *Context) { + ctx.Request.URL = urlValue + } +} + +func withBody(body string) contextModifier { + bodyReader := io.NopCloser(bytes.NewBuffer([]byte(body))) + + return withBodyReader(bodyReader) +} + +func withBodyReader(bodyReader io.ReadCloser) contextModifier { + return func(ctx *Context) { + ctx.Request.Body = bodyReader + } +} + +func withParams(params *httprouter.Params) contextModifier { + return func(ctx *Context) { + ctx.Params = params + } +} + +func withHeader(header string, values ...string) contextModifier { + return func(ctx *Context) { + for _, value := range values { + ctx.Request.Header.Add(header, value) + } + } +} + +func withOperation(operation *openapi3.Operation) contextModifier { + return func(ctx *Context) { + ctx.Operation.Operation = operation + } +} + +func withResponseWriter(writer http.ResponseWriter) contextModifier { + return func(ctx *Context) { + ctx.Writer = writer + } } func TestNewBadRequestErr(t *testing.T) { diff --git a/router/schema_validator/array_schema_validator.go b/router/schema_validator/array_schema_validator.go index d3429f5..3c76c8a 100644 --- a/router/schema_validator/array_schema_validator.go +++ b/router/schema_validator/array_schema_validator.go @@ -10,7 +10,7 @@ func (c typeSchemaValidatorContext) validateArraySchema() { } if !isSchemaTypeArrayOrEmpty(c.schema) { - if isArrayGoType(c.goType) { + if isArrayGoType(c.goType) && !isSliceOfBytes(c.goType) { c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) } return diff --git a/router/schema_validator/assetions.go b/router/schema_validator/assetions.go index 10d7330..922055f 100644 --- a/router/schema_validator/assetions.go +++ b/router/schema_validator/assetions.go @@ -11,16 +11,18 @@ import ( ) var ( - timeType = utils.GetType[time.Time]() - uuidType = utils.GetType[uuid.UUID]() + timeType = utils.GetType[time.Time]() + uuidType = utils.GetType[uuid.UUID]() + sliceOfBytesType = utils.GetType[[]byte]() isString = kindIs(reflect.String) isUUIDCompatible = anyOf(isString, convertibleTo(uuidType)) + isSliceOfBytes = anyOf(isString, typeIs(sliceOfBytesType)) isTimeCompatible = anyOf(isString, convertibleTo(timeType)) - isSerializedFromString = anyOf(isString, isUUIDCompatible, isTimeCompatible) + isSerializedFromString = anyOf(isString, isUUIDCompatible, isTimeCompatible, isSliceOfBytes) isTimeFormat = schemaFormatIs(dateTimeFormat, timeFormat) - isSchemaStringFormat = schemaFormatIs(uuidFormat, dateTimeFormat, timeFormat, dateFormat, durationFormat, + isSchemaStringFormat = schemaFormatIs(uuidFormat, byteFormat, dateTimeFormat, timeFormat, dateFormat, durationFormat, emailFormat, idnEmailFormat, hostnameFormat, idnHostnameFormat, ipv4Format, ipv6Format, uriFormat, uriReferenceFormat, iriFormat, iriReferenceFormat, uriTemplateFormat, jsonPointerFormat, relativeJsonPointerFormat, regexFormat, passwordFormat) @@ -98,6 +100,12 @@ func kindIs(kinds ...reflect.Kind) typeAssertion { }) } +func typeIs(types ...reflect.Type) typeAssertion { + return handleMultiType(func(t reflect.Type) bool { + return utils.NewSet(types...).Has(t) + }) +} + func convertibleTo(targets ...reflect.Type) typeAssertion { return handleMultiType(func(t reflect.Type) bool { for _, target := range targets { diff --git a/router/schema_validator/object_schema_validator.go b/router/schema_validator/object_schema_validator.go index 609eaf4..431f0aa 100644 --- a/router/schema_validator/object_schema_validator.go +++ b/router/schema_validator/object_schema_validator.go @@ -34,8 +34,11 @@ func (c typeSchemaValidatorContext) validateObjectSchema() { return c.assertStruct(t) } - // kind must be struct or object because we validated above serializedFromObject - return c.assertMap(t) + if t.Kind() == reflect.Map { + return c.assertMap(t) + } + + return false })(c.goType) } diff --git a/router/schema_validator/schema_format_validator.go b/router/schema_validator/schema_format_validator.go index 5683fd5..c4573d6 100644 --- a/router/schema_validator/schema_format_validator.go +++ b/router/schema_validator/schema_format_validator.go @@ -47,6 +47,9 @@ const ( // A string instance is valid against this attribute if it is a valid string representation of a UUID, according to [RFC4122]. uuidFormat = "uuid" // use with openapi3.TypeString + // A string instance is valid against this attribute if it is a valid base64 string. + byteFormat = "byte" + // use with openapi3.TypeString // This attribute applies to string instances. // A string instance is valid against this attribute if it is a valid URI Template (of any level), according to [RFC6570]. // Note that URI Templates may be used for IRIs; there is no separate IRI Template specification. diff --git a/router/schema_validator/string_schema_validator.go b/router/schema_validator/string_schema_validator.go index 0487f9f..5999305 100644 --- a/router/schema_validator/string_schema_validator.go +++ b/router/schema_validator/string_schema_validator.go @@ -25,6 +25,14 @@ func (c typeSchemaValidatorContext) validateStringSchema() { return } + // if schema format is "byte" expect type to be compatible with []byte + if c.schema.Format == byteFormat { + if (c.schema.Type == openapi3.TypeString || isSerializedFromString(c.goType)) && !isSliceOfBytes(c.goType) { + c.err(schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) + } + return + } + // if schema format is "uuid" expect type to be compatible with UUID if c.schema.Format == uuidFormat { if (c.schema.Type == openapi3.TypeString || isSerializedFromString(c.goType)) && !isUUIDCompatible(c.goType) { @@ -46,8 +54,4 @@ func (c typeSchemaValidatorContext) validateStringSchema() { c.err(schemaTypeWithFormatIsIncompatibleWithType(c.schema, c.goType)) return } - - //if isString(c.goType) { - // c.err(schemaTypeIsIncompatibleWithType(c.schema, c.goType)) - //} } diff --git a/router/schema_validator/string_schema_validator_test.go b/router/schema_validator/string_schema_validator_test.go index 08a6cfc..b1b181b 100644 --- a/router/schema_validator/string_schema_validator_test.go +++ b/router/schema_validator/string_schema_validator_test.go @@ -12,6 +12,16 @@ import ( "github.com/piiano/cellotape/router/utils" ) +func TestStringSchemaValidatorWithByteFormat(t *testing.T) { + stringSchema := openapi3.NewStringSchema() + stringSchema.Format = "byte" + + validator := schemaValidator(*stringSchema) + errTemplate := "expect string schema to be compatible with %s type" + bytes := reflect.TypeOf([]byte{}) + expectTypeToBeCompatible(t, validator, bytes, errTemplate, bytes) +} + func TestStringSchemaValidatorPassForStringType(t *testing.T) { stringSchema := openapi3.NewStringSchema() validator := schemaValidator(*stringSchema) @@ -23,7 +33,11 @@ func TestStringSchemaValidatorPassForStringType(t *testing.T) { func TestStringSchemaValidatorWithUntypedSchema(t *testing.T) { untypedSchemaWithUUIDFormat := openapi3.NewSchema().WithFormat(uuidFormat) - for _, validType := range types { + otherNonStringTypes := utils.Filter(types, func(t reflect.Type) bool { + return t != sliceOfBytesType && t != timeType + }) + + for _, validType := range otherNonStringTypes { t.Run(validType.String(), func(t *testing.T) { err := schemaValidator(*untypedSchemaWithUUIDFormat).WithType(validType).Validate() require.NoErrorf(t, err, "expect untyped schema to be compatible with %s type", validType)