diff --git a/types.go b/types.go index cf6e947..14b9f7c 100644 --- a/types.go +++ b/types.go @@ -232,3 +232,55 @@ func (d Duration) String() string { var _ encoding.TextUnmarshaler = (*Duration)(nil) var _ encoding.TextMarshaler = (*Duration)(nil) + +// OptionalInteger represents an integer that has a default value +// +// When encoded in json, Default is encoded as "null" +type OptionalInteger struct { + value *int64 +} + +// WithDefault resolves the integer with the given default. +func (p OptionalInteger) WithDefault(defaultValue int64) (value int64) { + if p.value == nil { + return defaultValue + } + return *p.value +} + +// IsDefault returns if this is a default optional integer +func (p OptionalInteger) IsDefault() bool { + return p.value == nil +} + +func (p OptionalInteger) MarshalJSON() ([]byte, error) { + if p.value != nil { + return json.Marshal(p.value) + } + return json.Marshal(nil) +} + +func (p *OptionalInteger) UnmarshalJSON(input []byte) error { + switch string(input) { + case "null", "undefined": + *p = OptionalInteger{} + default: + var value int64 + err := json.Unmarshal(input, &value) + if err != nil { + return err + } + *p = OptionalInteger{value: &value} + } + return nil +} + +func (p OptionalInteger) String() string { + if p.value == nil { + return "default" + } + return fmt.Sprintf("%d", p.value) +} + +var _ json.Unmarshaler = (*OptionalInteger)(nil) +var _ json.Marshaler = (*OptionalInteger)(nil) diff --git a/types_test.go b/types_test.go index 8d4c62f..94a7a63 100644 --- a/types_test.go +++ b/types_test.go @@ -218,3 +218,93 @@ func TestPriority(t *testing.T) { } } } + +func TestOptionalInteger(t *testing.T) { + makeInt64Pointer := func(v int64) *int64 { + return &v + } + + var defaultOptionalInt OptionalInteger + if !defaultOptionalInt.IsDefault() { + t.Fatal("should be the default") + } + if val := defaultOptionalInt.WithDefault(0); val != 0 { + t.Errorf("optional integer should have been 0, got %d", val) + } + + if val := defaultOptionalInt.WithDefault(1); val != 1 { + t.Errorf("optional integer should have been 1, got %d", val) + } + + if val := defaultOptionalInt.WithDefault(-1); val != -1 { + t.Errorf("optional integer should have been -1, got %d", val) + } + + var filledInt OptionalInteger + filledInt = OptionalInteger{value: makeInt64Pointer(1)} + if filledInt.IsDefault() { + t.Fatal("should not be the default") + } + if val := filledInt.WithDefault(0); val != 1 { + t.Errorf("optional integer should have been 1, got %d", val) + } + + if val := filledInt.WithDefault(-1); val != 1 { + t.Errorf("optional integer should have been 1, got %d", val) + } + + filledInt = OptionalInteger{value: makeInt64Pointer(0)} + if val := filledInt.WithDefault(1); val != 0 { + t.Errorf("optional integer should have been 0, got %d", val) + } + + for jsonStr, goValue := range map[string]OptionalInteger{ + "null": {}, + "0": {value: makeInt64Pointer(0)}, + "1": {value: makeInt64Pointer(1)}, + "-1": {value: makeInt64Pointer(-1)}, + } { + var d OptionalInteger + err := json.Unmarshal([]byte(jsonStr), &d) + if err != nil { + t.Fatal(err) + } + + if goValue.value == nil && d.value == nil { + } else if goValue.value == nil && d.value != nil { + t.Errorf("expected default, got %s", d) + } else if *d.value != *goValue.value { + t.Fatalf("expected %s, got %s", goValue, d) + } + + // Reverse + out, err := json.Marshal(goValue) + if err != nil { + t.Fatal(err) + } + if string(out) != jsonStr { + t.Fatalf("expected %s, got %s", jsonStr, string(out)) + } + } + + type Foo struct { + I *OptionalInteger `json:",omitempty"` + } + out, err := json.Marshal(new(Foo)) + if err != nil { + t.Fatal(err) + } + expected := "{}" + if string(out) != expected { + t.Fatal("expected omitempty to omit the optional integer") + } + for _, invalid := range []string{ + "foo", "-1.1", "1.1", "0.0", "[]", + } { + var p Priority + err := json.Unmarshal([]byte(invalid), &p) + if err == nil { + t.Errorf("expected to fail to decode %s as a priority", invalid) + } + } +}