From d1d4e26a7b0226cae92cc9a359f4737ebac52ca0 Mon Sep 17 00:00:00 2001 From: Henning Rogge Date: Fri, 17 Feb 2023 08:13:47 +0100 Subject: [PATCH] Allow defaults library to use UnmarshalText() and UnmarshalJSON() interface to initialize values --- defaults.go | 23 +++++++++++++++++++++++ defaults_test.go | 43 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/defaults.go b/defaults.go index 3bb2c22..b5e7eb9 100644 --- a/defaults.go +++ b/defaults.go @@ -1,6 +1,7 @@ package defaults import ( + "encoding" "encoding/json" "errors" "reflect" @@ -61,6 +62,10 @@ func setField(field reflect.Value, defaultVal string) error { isInitial := isInitialValue(field) if isInitial { + if unmarshalByInterface(field, defaultVal) { + return nil + } + switch field.Kind() { case reflect.Bool: if val, err := strconv.ParseBool(defaultVal); err == nil { @@ -194,6 +199,24 @@ func setField(field reflect.Value, defaultVal string) error { return nil } +func unmarshalByInterface(field reflect.Value, defaultVal string) bool { + asText, ok := field.Addr().Interface().(encoding.TextUnmarshaler) + if ok && defaultVal != "" { + // if field implements encode.TextUnmarshaler, try to use it before decode by kind + if err := asText.UnmarshalText([]byte(defaultVal)); err == nil { + return true + } + } + asJSON, ok := field.Addr().Interface().(json.Unmarshaler) + if ok && defaultVal != "" && defaultVal != "{}" && defaultVal != "[]" { + // if field implements json.Unmarshaler, try to use it before decode by kind + if err := asJSON.UnmarshalJSON([]byte(defaultVal)); err == nil { + return true + } + } + return false +} + func isInitialValue(field reflect.Value) bool { return reflect.DeepEqual(reflect.Zero(field.Type()).Interface(), field.Interface()) } diff --git a/defaults_test.go b/defaults_test.go index 5bfba7b..350ad62 100644 --- a/defaults_test.go +++ b/defaults_test.go @@ -1,7 +1,11 @@ package defaults import ( + "encoding/json" + "errors" + "net" "reflect" + "strconv" "testing" "time" @@ -112,9 +116,12 @@ type Sample struct { MyMap MyMap `default:"{}"` MySlice MySlice `default:"[]"` - StructWithJSON Struct `default:"{\"Foo\": 123}"` - StructPtrWithJSON *Struct `default:"{\"Foo\": 123}"` - MapWithJSON map[string]int `default:"{\"foo\": 123}"` + StructWithText net.IP `default:"10.0.0.1"` + StructPtrWithText *net.IP `default:"10.0.0.1"` + StructWithJSON Struct `default:"{\"Foo\": 123}"` + StructPtrWithJSON *Struct `default:"{\"Foo\": 123}"` + MapWithJSON map[string]int `default:"{\"foo\": 123}"` + TypeWithUnmarshalJSON JSONOnlyType `default:"\"one\""` MapOfPtrStruct map[string]*Struct MapOfStruct map[string]Struct @@ -155,6 +162,24 @@ type Embedded struct { Int int `default:"1"` } +type JSONOnlyType int + +func (j *JSONOnlyType) UnmarshalJSON(b []byte) error { + var tmp string + if err := json.Unmarshal(b, &tmp); err != nil { + return err + } + if i, err := strconv.Atoi(tmp); err == nil { + *j = JSONOnlyType(i) + return nil + } + if tmp == "one" { + *j = 1 + return nil + } + return errors.New("cannot unmarshal") +} + func TestMustSet(t *testing.T) { t.Run("right way", func(t *testing.T) { @@ -485,6 +510,14 @@ func TestInit(t *testing.T) { } }) + t.Run("complex types with text unmarshal", func(t *testing.T) { + if !sample.StructWithText.Equal(net.ParseIP("10.0.0.1")) { + t.Errorf("it should initialize struct with text") + } + if !sample.StructPtrWithText.Equal(net.ParseIP("10.0.0.1")) { + t.Errorf("it should initialize struct with text") + } + }) t.Run("complex types with json", func(t *testing.T) { if sample.StructWithJSON.Foo != 123 { t.Errorf("it should initialize struct with json") @@ -499,6 +532,10 @@ func TestInit(t *testing.T) { t.Errorf("it should initialize slice with json") } + if int(sample.TypeWithUnmarshalJSON) != 1 { + t.Errorf("it should initialize json unmarshaled value") + } + t.Run("invalid json", func(t *testing.T) { if err := Set(&struct { I []int `default:"[!]"`