diff --git a/api/config/v1/replicas.go b/api/config/v1/replicas.go index c24c7109c..6c404446b 100644 --- a/api/config/v1/replicas.go +++ b/api/config/v1/replicas.go @@ -269,59 +269,61 @@ func (s *ReplicatedResource) UnmarshalJSON(b []byte) error { // UnmarshalJSON unmarshals raw bytes into a 'ReplicatedDevices' struct. func (s *ReplicatedDevices) UnmarshalJSON(b []byte) error { - // Match the string 'all' var str string - err := json.Unmarshal(b, &str) - if err == nil { - if str != "all" { - return fmt.Errorf("devices set as '%v' but the only valid string input is 'all'", str) - } - s.All = true - return nil + if err := json.Unmarshal(b, &str); err == nil { + return handleStringInput(str, s) } - // Match a count var count int - err = json.Unmarshal(b, &count) - if err == nil { - if count <= 0 { - return fmt.Errorf("devices set as '%v' but a count of devices must be > 0", count) - } - s.Count = count - return nil + if err := json.Unmarshal(b, &count); err == nil { + return handleIntInput(count, s) } - // Match a list var slice []json.RawMessage - err = json.Unmarshal(b, &slice) - if err == nil { - // For each item in the list check its format and convert it to a string (if necessary) - result := make([]ReplicatedDeviceRef, len(slice)) - for i, s := range slice { - // Match a uint as a GPU index and convert it to a string - var index uint - if err = json.Unmarshal(s, &index); err == nil { - result[i] = ReplicatedDeviceRef(strconv.Itoa(int(index))) + if err := json.Unmarshal(b, &slice); err == nil { + return handleListInput(slice, s) + } + + return fmt.Errorf("unrecognized type for devices spec: %s", string(b)) +} + +func handleStringInput(str string, s *ReplicatedDevices) error { + if str != "all" { + return fmt.Errorf("devices set as '%v' but the only valid string input is 'all'", str) + } + s.All = true + return nil +} + +func handleIntInput(count int, s *ReplicatedDevices) error { + if count <= 0 { + return fmt.Errorf("devices set as '%v' but a count of devices must be > 0", count) + } + s.Count = count + return nil +} + +func handleListInput(slice []json.RawMessage, s *ReplicatedDevices) error { + result := make([]ReplicatedDeviceRef, len(slice)) + for i, raw := range slice { + var index uint + if err := json.Unmarshal(raw, &index); err == nil { + result[i] = ReplicatedDeviceRef(strconv.Itoa(int(index))) + continue + } + + var item string + if err := json.Unmarshal(raw, &item); err == nil { + rd := ReplicatedDeviceRef(item) + if rd.IsGPUIndex() || rd.IsMigIndex() || rd.IsUUID() { + result[i] = rd continue } - // Match strings as valid entries if they are GPU indices, MIG indices, or UUIDs - var item string - if err = json.Unmarshal(s, &item); err == nil { - rd := ReplicatedDeviceRef(item) - if rd.IsGPUIndex() || rd.IsMigIndex() || rd.IsUUID() { - result[i] = rd - continue - } - } - // Treat any other entries as errors - return fmt.Errorf("unsupported type for device in devices list: %v, %T", item, item) } - s.List = result - return nil + return fmt.Errorf("unsupported type for device in devices list: %v, %T", item, item) } - - // No matches found - return fmt.Errorf("unrecognized type for devices spec: %v", string(b)) + s.List = result + return nil } // MarshalJSON marshals ReplicatedDevices to its raw bytes representation diff --git a/api/config/v1/replicas_test.go b/api/config/v1/replicas_test.go index 8262c4376..e8eae3d40 100644 --- a/api/config/v1/replicas_test.go +++ b/api/config/v1/replicas_test.go @@ -18,6 +18,7 @@ package v1 import ( "fmt" + "reflect" "testing" "github.com/stretchr/testify/require" @@ -464,3 +465,30 @@ func TestUnmarshalReplicatedResources(t *testing.T) { }) } } +func TestUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + want ReplicatedDevices + wantErr bool + }{ + {"All devices", `"all"`, ReplicatedDevices{All: true}, false}, + {"Count of devices", `3`, ReplicatedDevices{Count: 3}, false}, + {"Devices list", `["0:1"]`, ReplicatedDevices{List: []ReplicatedDeviceRef{"0:1"}}, false}, + {"Invalid datatype", `{"key":"value"}`, ReplicatedDevices{}, true}, + {"Invalid device count", `0`, ReplicatedDevices{}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got ReplicatedDevices + if err := got.UnmarshalJSON([]byte(tt.input)); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +}