diff --git a/scw/custom_types.go b/scw/custom_types.go index b7dfe8231..5d723b370 100644 --- a/scw/custom_types.go +++ b/scw/custom_types.go @@ -5,9 +5,8 @@ import ( "encoding/json" "fmt" "io" + "net" "time" - - "github.com/scaleway/scaleway-sdk-go/internal/errors" ) // ServiceInfo contains API metadata @@ -123,7 +122,7 @@ type TimeSeriesPoint struct { Value float32 } -func (tsp *TimeSeriesPoint) MarshalJSON() ([]byte, error) { +func (tsp TimeSeriesPoint) MarshalJSON() ([]byte, error) { timestamp := tsp.Timestamp.Format(time.RFC3339) value, err := json.Marshal(tsp.Value) if err != nil { @@ -142,25 +141,66 @@ func (tsp *TimeSeriesPoint) UnmarshalJSON(b []byte) error { } if len(point) != 2 { - return errors.New("invalid point array") + return fmt.Errorf("invalid point array") } strTimestamp, isStrTimestamp := point[0].(string) if !isStrTimestamp { - return errors.New("%s timestamp is not a string in RFC 3339 format", point[0]) + return fmt.Errorf("%s timestamp is not a string in RFC 3339 format", point[0]) } timestamp, err := time.Parse(time.RFC3339, strTimestamp) if err != nil { - return errors.New("%s timestamp is not in RFC 3339 format", point[0]) + return fmt.Errorf("%s timestamp is not in RFC 3339 format", point[0]) } tsp.Timestamp = timestamp // By default, JSON unmarshal a float in float64 but the TimeSeriesPoint is a float32 value. value, isValue := point[1].(float64) if !isValue { - return errors.New("%s is not a valid float32 value", point[1]) + return fmt.Errorf("%s is not a valid float32 value", point[1]) } tsp.Value = float32(value) return nil } + +// IPNet inherits net.IPNet and represents an IP network. +type IPNet struct { + net.IPNet +} + +func (n IPNet) MarshalJSON() ([]byte, error) { + value := n.String() + if value == "" { + value = "" + } + return []byte(`"` + value + `"`), nil +} + +func (n *IPNet) UnmarshalJSON(b []byte) error { + var str string + + err := json.Unmarshal(b, &str) + if err != nil { + return err + } + if str == "" { + *n = IPNet{} + return nil + } + + switch ip := net.ParseIP(str); { + case ip.To4() != nil: + str += "/32" + case ip.To16() != nil: + str += "/128" + } + + _, value, err := net.ParseCIDR(str) + if err != nil { + return err + } + n.IPNet = *value + + return nil +} diff --git a/scw/custom_types_test.go b/scw/custom_types_test.go index d57ca9466..587ee8c04 100644 --- a/scw/custom_types_test.go +++ b/scw/custom_types_test.go @@ -2,11 +2,12 @@ package scw import ( "encoding/json" + "fmt" "io/ioutil" + "net" "testing" "time" - "github.com/scaleway/scaleway-sdk-go/internal/errors" "github.com/scaleway/scaleway-sdk-go/internal/testhelpers" ) @@ -67,7 +68,7 @@ func TestTimeSeries_MarshallJSON(t *testing.T) { } } -func TestTimeSeries_UnmarshallJSON(t *testing.T) { +func TestTimeSeries_UnmarshalJSON(t *testing.T) { cases := []struct { name string json string @@ -108,7 +109,7 @@ func TestTimeSeries_UnmarshallJSON(t *testing.T) { { name: "with timestamp error", json: `{"name":"cpu_usage","points":[["2019/08/08T15-00-00Z",0.2]]}`, - err: errors.New("2019/08/08T15-00-00Z timestamp is not in RFC 3339 format"), + err: fmt.Errorf("2019/08/08T15-00-00Z timestamp is not in RFC 3339 format"), }, } @@ -166,3 +167,86 @@ func TestFile_UnmarshalJSON(t *testing.T) { content: []byte("\x00\x00\x00\n"), })) } + +func TestIPNet_MarshallJSON(t *testing.T) { + cases := []struct { + name string + ipRange IPNet + want string + err error + }{ + { + name: "ip", + ipRange: IPNet{IPNet: net.IPNet{IP: net.IPv4(42, 42, 42, 42), Mask: net.CIDRMask(32, 32)}}, + want: `"42.42.42.42/32"`, + }, + { + name: "network", + ipRange: IPNet{IPNet: net.IPNet{IP: net.IPv4(42, 42, 42, 42), Mask: net.CIDRMask(16, 32)}}, + want: `"42.42.42.42/16"`, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := json.Marshal(c.ipRange) + + testhelpers.Equals(t, c.err, err) + if c.err == nil { + testhelpers.Equals(t, c.want, string(got)) + } + }) + } +} + +func TestIPNet_UnmarshalJSON(t *testing.T) { + cases := []struct { + name string + json string + want IPNet + err string + }{ + { + name: "IPv4 with CIDR", + json: `"42.42.42.42/32"`, + want: IPNet{IPNet: net.IPNet{IP: net.IPv4(42, 42, 42, 42), Mask: net.CIDRMask(32, 32)}}, + }, + { + name: "IPv4 with network", + json: `"192.0.2.1/24"`, + want: IPNet{IPNet: net.IPNet{IP: net.IPv4(192, 0, 2, 0), Mask: net.CIDRMask(24, 32)}}, + }, + { + name: "IPv6 with network", + json: `"2001:db8:abcd:8000::/50"`, + want: IPNet{IPNet: net.IPNet{IP: net.ParseIP("2001:db8:abcd:8000::"), Mask: net.CIDRMask(50, 128)}}, + }, + { + name: "IPv4 alone", + json: `"42.42.42.42"`, + want: IPNet{IPNet: net.IPNet{IP: net.IPv4(42, 42, 42, 42), Mask: net.CIDRMask(32, 32)}}, + }, + { + name: "IPv6 alone", + json: `"2001:db8:abcd:8000::"`, + want: IPNet{IPNet: net.IPNet{IP: net.ParseIP("2001:db8:abcd:8000::"), Mask: net.CIDRMask(128, 128)}}, + }, + { + name: "invalid CIDR error", + json: `"invalidvalue"`, + err: "invalid CIDR address: invalidvalue", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + ipNet := &IPNet{} + err := json.Unmarshal([]byte(c.json), ipNet) + if err != nil { + testhelpers.Equals(t, c.err, err.Error()) + } + + testhelpers.Equals(t, c.want.String(), ipNet.String()) + }) + } +}