From d5b32c6e47569961c9b03fe96ca36be7c9d0a02b Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 20 Oct 2024 20:14:34 +0300 Subject: [PATCH] Refactor work done by martinpasaribu (binding multipart files by using struct tags) --- bind.go | 117 +++++++++++++--------- bind_test.go | 277 ++++++++++++++++++++++++--------------------------- 2 files changed, 200 insertions(+), 194 deletions(-) diff --git a/bind.go b/bind.go index 157be09fc..ed7ca3249 100644 --- a/bind.go +++ b/bind.go @@ -46,7 +46,7 @@ func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { for i, name := range names { params[name] = []string{values[i]} } - if err := b.bindData(i, params, "param"); err != nil { + if err := b.bindData(i, params, "param", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil @@ -54,7 +54,7 @@ func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { // BindQueryParams binds query params to bindable object func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { - if err := b.bindData(i, c.QueryParams(), "query"); err != nil { + if err := b.bindData(i, c.QueryParams(), "query", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil @@ -71,9 +71,12 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { return } - ctype := req.Header.Get(HeaderContentType) - switch { - case strings.HasPrefix(ctype, MIMEApplicationJSON): + // mediatype is found like `mime.ParseMediaType()` does it + base, _, _ := strings.Cut(req.Header.Get(HeaderContentType), ";") + mediatype := strings.TrimSpace(base) + + switch mediatype { + case MIMEApplicationJSON: if err = c.Echo().JSONSerializer.Deserialize(c, i); err != nil { switch err.(type) { case *HTTPError: @@ -82,7 +85,7 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } } - case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): + case MIMEApplicationXML, MIMETextXML: if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*xml.UnsupportedTypeError); ok { return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) @@ -91,15 +94,15 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { } return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - case strings.HasPrefix(ctype, MIMEApplicationForm): + case MIMEApplicationForm: params, err := c.FormParams() if err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - if err = b.bindData(i, params, "form"); err != nil { + if err = b.bindData(i, params, "form", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - case strings.HasPrefix(ctype, MIMEMultipartForm): + case MIMEMultipartForm: params, err := c.MultipartForm() if err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) @@ -115,7 +118,7 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { // BindHeaders binds HTTP headers to a bindable object func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { - if err := b.bindData(i, c.Request().Header, "header"); err != nil { + if err := b.bindData(i, c.Request().Header, "header", nil); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil @@ -141,10 +144,11 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag -func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string, files ...map[string][]*multipart.FileHeader) error { - if destination == nil || (len(data) == 0 && len(files) == 0) { +func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error { + if destination == nil || (len(data) == 0 && len(dataFiles) == 0) { return nil } + hasFiles := len(dataFiles) > 0 typ := reflect.TypeOf(destination).Elem() val := reflect.ValueOf(destination).Elem() @@ -188,7 +192,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri return errors.New("binding element must be a struct") } - for i := 0; i < typ.NumField(); i++ { + for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields typeField := typ.Field(i) structField := val.Field(i) if typeField.Anonymous { @@ -207,10 +211,10 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri } if inputFieldName == "" { - // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). + // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). // structs that implement BindUnmarshaler are bound only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { - if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { + if err := b.bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil { return err } } @@ -218,37 +222,20 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri continue } - // Handle multiple file uploads ([]*multipart.FileHeader, *multipart.FileHeader, []multipart.FileHeader) - if len(files) > 0 && isMultipartFile(structField.Type()) { - for _, fileMap := range files { - fileHeaders, exists := fileMap[inputFieldName] - if exists && len(fileHeaders) > 0 { - switch structField.Type() { - case reflect.TypeOf([]*multipart.FileHeader(nil)): - structField.Set(reflect.ValueOf(fileHeaders)) - continue - case reflect.TypeOf([]multipart.FileHeader(nil)): - headers := make([]multipart.FileHeader, len(fileHeaders)) - for i, fileHeader := range fileHeaders { - headers[i] = *fileHeader - } - structField.Set(reflect.ValueOf(headers)) - continue - case reflect.TypeOf(&multipart.FileHeader{}): - structField.Set(reflect.ValueOf(fileHeaders[0])) - continue - case reflect.TypeOf(multipart.FileHeader{}): - structField.Set(reflect.ValueOf(*fileHeaders[0])) - continue - } + if hasFiles { + if ok, err := isFieldMultipartFile(structField.Type()); err != nil { + return err + } else if ok { + if ok := setMultipartFileHeaderTypes(structField, inputFieldName, dataFiles); ok { + continue } } } inputValue, exists := data[inputFieldName] if !exists { - // Go json.Unmarshal supports case insensitive binding. However the - // url params are bound case sensitive which is inconsistent. To + // Go json.Unmarshal supports case-insensitive binding. However the + // url params are bound case-sensitive which is inconsistent. To // fix this we must check all of the map values in a // case-insensitive search. for k, v := range data { @@ -431,9 +418,49 @@ func setFloatField(value string, bitSize int, field reflect.Value) error { return err } -func isMultipartFile(field reflect.Type) bool { - return reflect.TypeOf(&multipart.FileHeader{}) == field || - reflect.TypeOf(multipart.FileHeader{}) == field || - reflect.TypeOf([]*multipart.FileHeader(nil)) == field || - reflect.TypeOf([]multipart.FileHeader(nil)) == field +var ( + // NOT supported by bind as you can NOT check easily empty struct being actual file or not + multipartFileHeaderType = reflect.TypeOf(multipart.FileHeader{}) + // supported by bind as you can check by nil value if file existed or not + multipartFileHeaderPointerType = reflect.TypeOf(&multipart.FileHeader{}) + multipartFileHeaderSliceType = reflect.TypeOf([]multipart.FileHeader(nil)) + multipartFileHeaderPointerSliceType = reflect.TypeOf([]*multipart.FileHeader(nil)) +) + +func isFieldMultipartFile(field reflect.Type) (bool, error) { + switch field { + case multipartFileHeaderPointerType, + multipartFileHeaderSliceType, + multipartFileHeaderPointerSliceType: + return true, nil + case multipartFileHeaderType: + return true, errors.New("binding to multipart.FileHeader struct is not supported, use pointer to struct") + default: + return false, nil + } +} + +func setMultipartFileHeaderTypes(structField reflect.Value, inputFieldName string, files map[string][]*multipart.FileHeader) bool { + fileHeaders := files[inputFieldName] + if len(fileHeaders) == 0 { + return false + } + + result := true + switch structField.Type() { + case multipartFileHeaderPointerSliceType: + structField.Set(reflect.ValueOf(fileHeaders)) + case multipartFileHeaderSliceType: + headers := make([]multipart.FileHeader, len(fileHeaders)) + for i, fileHeader := range fileHeaders { + headers[i] = *fileHeader + } + structField.Set(reflect.ValueOf(headers)) + case multipartFileHeaderPointerType: + structField.Set(reflect.ValueOf(fileHeaders[0])) + default: + result = false + } + + return result } diff --git a/bind_test.go b/bind_test.go index 2323a219b..c79669c8c 100644 --- a/bind_test.go +++ b/bind_test.go @@ -446,7 +446,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]string", func(t *testing.T) { dest := map[string]string{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", @@ -458,7 +458,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) { var dest map[string]string - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", @@ -470,7 +470,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string][]string", func(t *testing.T) { dest := map[string][]string{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, @@ -482,7 +482,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) { var dest map[string][]string - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, @@ -494,7 +494,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]interface", func(t *testing.T) { dest := map[string]interface{}{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]interface{}{ "multiple": "1", @@ -506,7 +506,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) { var dest map[string]interface{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]interface{}{ "multiple": "1", @@ -518,25 +518,25 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]int skips", func(t *testing.T) { dest := map[string]int{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int{}, dest) }) t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) { var dest map[string]int - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int(nil), dest) }) t.Run("ok, bind to map[string][]int skips", func(t *testing.T) { dest := map[string][]int{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int{}, dest) }) t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) { var dest map[string][]int - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int(nil), dest) }) } @@ -544,7 +544,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { func TestBindbindData(t *testing.T) { ts := new(bindTestStruct) b := new(DefaultBinder) - err := b.bindData(ts, values, "form") + err := b.bindData(ts, values, "form", nil) assert.NoError(t, err) assert.Equal(t, 0, ts.I) @@ -666,7 +666,7 @@ func BenchmarkBindbindDataWithTags(b *testing.B) { var err error b.ResetTimer() for i := 0; i < b.N; i++ { - err = binder.bindData(ts, values, "form") + err = binder.bindData(ts, values, "form", nil) } assert.NoError(b, err) assertBindTestStruct(b, (*bindTestStruct)(ts)) @@ -1102,143 +1102,6 @@ func TestDefaultBinder_BindBody(t *testing.T) { } } -type testFile struct { - Fieldname string - Filename string - Content []byte -} - -// createRequestMultipartFiles creates a multipart HTTP request with multiple files. -func createRequestMultipartFiles(t *testing.T, files ...testFile) *http.Request { - var body bytes.Buffer - mw := multipart.NewWriter(&body) - - for _, file := range files { - fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) - assert.NoError(t, err) - - n, err := fw.Write(file.Content) - assert.NoError(t, err) - assert.Equal(t, len(file.Content), n) - } - - err := mw.Close() - assert.NoError(t, err) - - req, err := http.NewRequest(http.MethodPost, "/", &body) - assert.NoError(t, err) - - req.Header.Set("Content-Type", mw.FormDataContentType()) - - return req -} - -func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file testFile) { - assert.Equal(t, file.Filename, fh.Filename) - assert.Equal(t, int64(len(file.Content)), fh.Size) - fl, err := fh.Open() - assert.NoError(t, err) - body, err := io.ReadAll(fl) - assert.NoError(t, err) - assert.Equal(t, string(file.Content), string(body)) - err = fl.Close() - assert.NoError(t, err) -} - -func TestFormMultipartBindTwoFiles(t *testing.T) { - var args struct { - Files []*multipart.FileHeader `form:"files"` - } - - files := []testFile{ - { - Fieldname: "files", - Filename: "file1.txt", - Content: []byte("This is the content of file 1."), - }, - { - Fieldname: "files", - Filename: "file2.txt", - Content: []byte("This is the content of file 2."), - }, - } - - e := New() - req := createRequestMultipartFiles(t, files...) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := c.Bind(&args) - assert.NoError(t, err) - - assert.Len(t, args.Files, len(files)) - for idx, file := range files { - assertMultipartFileHeader(t, args.Files[idx], file) - } -} - -func TestFormMultipartBindMultipleKeys(t *testing.T) { - var args struct { - Files []multipart.FileHeader `form:"files"` - File multipart.FileHeader `form:"file"` - } - - files := []testFile{ - { - Fieldname: "files", - Filename: "file1.txt", - Content: []byte("This is the content of file 1."), - }, - { - Fieldname: "files", - Filename: "file2.txt", - Content: []byte("This is the content of file 2."), - }, - } - file := testFile{ - Fieldname: "file", - Filename: "file3.txt", - Content: []byte("This is the content of file 3."), - } - - e := New() - req := createRequestMultipartFiles(t, append(files, file)...) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := c.Bind(&args) - assert.NoError(t, err) - - assert.Len(t, args.Files, len(files)) - for idx, file := range files { - argsFile := args.Files[idx] - assertMultipartFileHeader(t, &argsFile, file) - } - assertMultipartFileHeader(t, &args.File, file) -} - -func TestFormMultipartBindOneFile(t *testing.T) { - var args struct { - File *multipart.FileHeader `form:"file"` - } - - file := testFile{ - Fieldname: "file", - Filename: "file1.txt", - Content: []byte("This is the content of file 1."), - } - - e := New() - req := createRequestMultipartFiles(t, file) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := c.Bind(&args) - assert.NoError(t, err) - - assertMultipartFileHeader(t, args.File, file) -} - func testBindURL(queryString string, target any) error { e := New() req := httptest.NewRequest(http.MethodGet, queryString, nil) @@ -1557,3 +1420,119 @@ func TestBindInt8(t *testing.T) { assert.Equal(t, target{V: &[]int8{1, 2}}, p) }) } + +func TestBindMultipartFormFiles(t *testing.T) { + file1 := createTestFormFile("file", "file1.txt") + file11 := createTestFormFile("file", "file11.txt") + file2 := createTestFormFile("file2", "file2.txt") + filesA := createTestFormFile("files", "filesA.txt") + filesB := createTestFormFile("files", "filesB.txt") + + t.Run("nok, can not bind to multipart file struct", func(t *testing.T) { + var target struct { + File multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored + + assert.EqualError(t, err, "code=400, message=binding to multipart.FileHeader struct is not supported, use pointer to struct, internal=binding to multipart.FileHeader struct is not supported, use pointer to struct") + }) + + t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) { + var target struct { + File *multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored + + assert.NoError(t, err) + assertMultipartFileHeader(t, target.File, file1) + }) + + t.Run("ok, bind multiple multipart files to pointer to multipart file", func(t *testing.T) { + var target struct { + File *multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file11) + + assert.NoError(t, err) + assertMultipartFileHeader(t, target.File, file1) // should choose first one + }) + + t.Run("ok, bind multiple multipart files to slice of multipart file", func(t *testing.T) { + var target struct { + Files []multipart.FileHeader `form:"files"` + } + err := bindMultipartFiles(t, &target, filesA, filesB, file1) + + assert.NoError(t, err) + + assert.Len(t, target.Files, 2) + assertMultipartFileHeader(t, &target.Files[0], filesA) + assertMultipartFileHeader(t, &target.Files[1], filesB) + }) + + t.Run("ok, bind multiple multipart files to slice of pointer to multipart file", func(t *testing.T) { + var target struct { + Files []*multipart.FileHeader `form:"files"` + } + err := bindMultipartFiles(t, &target, filesA, filesB, file1) + + assert.NoError(t, err) + + assert.Len(t, target.Files, 2) + assertMultipartFileHeader(t, target.Files[0], filesA) + assertMultipartFileHeader(t, target.Files[1], filesB) + }) +} + +type testFormFile struct { + Fieldname string + Filename string + Content []byte +} + +func createTestFormFile(formFieldName string, filename string) testFormFile { + return testFormFile{ + Fieldname: formFieldName, + Filename: filename, + Content: []byte(strings.Repeat(filename, 10)), + } +} + +func bindMultipartFiles(t *testing.T, target any, files ...testFormFile) error { + var body bytes.Buffer + mw := multipart.NewWriter(&body) + + for _, file := range files { + fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) + assert.NoError(t, err) + + n, err := fw.Write(file.Content) + assert.NoError(t, err) + assert.Equal(t, len(file.Content), n) + } + + err := mw.Close() + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/", &body) + assert.NoError(t, err) + req.Header.Set("Content-Type", mw.FormDataContentType()) + + rec := httptest.NewRecorder() + + e := New() + c := e.NewContext(req, rec) + return c.Bind(target) +} + +func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file testFormFile) { + assert.Equal(t, file.Filename, fh.Filename) + assert.Equal(t, int64(len(file.Content)), fh.Size) + fl, err := fh.Open() + assert.NoError(t, err) + body, err := io.ReadAll(fl) + assert.NoError(t, err) + assert.Equal(t, string(file.Content), string(body)) + err = fl.Close() + assert.NoError(t, err) +}