diff --git a/README.md b/README.md index 09419ff7..77ab2d00 100644 --- a/README.md +++ b/README.md @@ -217,6 +217,7 @@ Conditional helpers: Type manipulation helpers: - [ToPtr](#toptr) +- [EmptyableToPtr](#emptyabletoptr) - [FromPtr](#fromptr) - [FromPtrOr](#fromptror) - [ToSlicePtr](#tosliceptr) @@ -2163,6 +2164,25 @@ ptr := lo.ToPtr("hello world") // *string{"hello world"} ``` +### EmptyableToPtr + +Returns a pointer copy of value if it's nonzero. +Otherwise, returns nil pointer. + +```go +ptr := lo.EmptyableToPtr[[]int](nil) +// nil + +ptr := lo.EmptyableToPtr[string]("") +// nil + +ptr := lo.EmptyableToPtr[[]int]([]int{}) +// *[]int{} + +ptr := lo.EmptyableToPtr[string]("hello world") +// *string{"hello world"} +``` + ### FromPtr Returns the pointer value or empty. diff --git a/type_manipulation.go b/type_manipulation.go index fe99ee1f..c392df37 100644 --- a/type_manipulation.go +++ b/type_manipulation.go @@ -1,10 +1,23 @@ package lo +import "reflect" + // ToPtr returns a pointer copy of value. func ToPtr[T any](x T) *T { return &x } +// EmptyableToPtr returns a pointer copy of value if it's nonzero. +// Otherwise, returns nil pointer. +func EmptyableToPtr[T any](x T) *T { + isZero := reflect.ValueOf(&x).Elem().IsZero() + if isZero { + return nil + } + + return &x +} + // FromPtr returns the pointer value or empty. func FromPtr[T any](x *T) T { if x == nil { diff --git a/type_manipulation_test.go b/type_manipulation_test.go index 807c8695..48e3676f 100644 --- a/type_manipulation_test.go +++ b/type_manipulation_test.go @@ -15,6 +15,24 @@ func TestToPtr(t *testing.T) { is.Equal(*result1, []int{1, 2}) } +func TestEmptyableToPtr(t *testing.T) { + t.Parallel() + is := assert.New(t) + + is.Nil(EmptyableToPtr(0)) + is.Nil(EmptyableToPtr("")) + is.Nil(EmptyableToPtr[[]int](nil)) + is.Nil(EmptyableToPtr[map[int]int](nil)) + is.Nil(EmptyableToPtr[error](nil)) + + is.Equal(*EmptyableToPtr(42), 42) + is.Equal(*EmptyableToPtr("nonempty"), "nonempty") + is.Equal(*EmptyableToPtr([]int{}), []int{}) + is.Equal(*EmptyableToPtr([]int{1, 2}), []int{1, 2}) + is.Equal(*EmptyableToPtr(map[int]int{}), map[int]int{}) + is.Equal(*EmptyableToPtr(assert.AnError), assert.AnError) +} + func TestFromPtr(t *testing.T) { t.Parallel() is := assert.New(t)