From 32ec92a20ca1927cc6a611df19b0c2a5eeb83d93 Mon Sep 17 00:00:00 2001 From: Geoffrey Ragot Date: Mon, 23 Sep 2024 16:47:49 +0200 Subject: [PATCH] fix(bunpaginate): time columns --- bun/bunpaginate/pagination_column.go | 32 ++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/bun/bunpaginate/pagination_column.go b/bun/bunpaginate/pagination_column.go index 680b6af..619cefe 100644 --- a/bun/bunpaginate/pagination_column.go +++ b/bun/bunpaginate/pagination_column.go @@ -114,15 +114,23 @@ func findPaginationFieldPath(v any, paginationColumn string) []reflect.StructFie field := typeOfT.Field(i) switch field.Type.Kind() { case reflect.Struct: - fields := findPaginationFieldPath(reflect.New(field.Type).Elem().Interface(), paginationColumn) - if len(fields) > 0 { - return fields + if field.Type.AssignableTo(reflect.TypeOf(time.Time{})) || + field.Type.AssignableTo(reflect.TypeOf(&time.Time{})) || + field.Type.AssignableTo(reflect.TypeOf(libtime.Time{})) || + field.Type.AssignableTo(reflect.TypeOf(&libtime.Time{})) { + + if fields := checkTag(field, paginationColumn); len(fields) > 0 { + return fields + } + } else { + fields := findPaginationFieldPath(reflect.New(field.Type).Elem().Interface(), paginationColumn) + if len(fields) > 0 { + return fields + } } default: - tag := field.Tag.Get("bun") - column := strings.Split(tag, ",")[0] - if column == paginationColumn { - return []reflect.StructField{field} + if fields := checkTag(field, paginationColumn); len(fields) > 0 { + return fields } } } @@ -130,6 +138,16 @@ func findPaginationFieldPath(v any, paginationColumn string) []reflect.StructFie return nil } +func checkTag(field reflect.StructField, paginationColumn string) []reflect.StructField { + tag := field.Tag.Get("bun") + column := strings.Split(tag, ",")[0] + if column == paginationColumn { + return []reflect.StructField{field} + } + + return nil +} + func findPaginationField(v any, fields ...reflect.StructField) *big.Int { vOf := reflect.ValueOf(v) field := vOf.FieldByName(fields[0].Name)