Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Simplify UnmarshalJSON with Dedicated Handlers #695

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 44 additions & 42 deletions api/config/v1/replicas.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,59 +269,61 @@ func (s *ReplicatedResource) UnmarshalJSON(b []byte) error {

// UnmarshalJSON unmarshals raw bytes into a 'ReplicatedDevices' struct.
func (s *ReplicatedDevices) UnmarshalJSON(b []byte) error {
// Match the string 'all'
var str string
err := json.Unmarshal(b, &str)
if err == nil {
if str != "all" {
return fmt.Errorf("devices set as '%v' but the only valid string input is 'all'", str)
}
s.All = true
return nil
if err := json.Unmarshal(b, &str); err == nil {
return handleStringInput(str, s)
}

// Match a count
var count int
err = json.Unmarshal(b, &count)
if err == nil {
if count <= 0 {
return fmt.Errorf("devices set as '%v' but a count of devices must be > 0", count)
}
s.Count = count
return nil
if err := json.Unmarshal(b, &count); err == nil {
return handleIntInput(count, s)
}

// Match a list
var slice []json.RawMessage
err = json.Unmarshal(b, &slice)
if err == nil {
// For each item in the list check its format and convert it to a string (if necessary)
result := make([]ReplicatedDeviceRef, len(slice))
for i, s := range slice {
// Match a uint as a GPU index and convert it to a string
var index uint
if err = json.Unmarshal(s, &index); err == nil {
result[i] = ReplicatedDeviceRef(strconv.Itoa(int(index)))
if err := json.Unmarshal(b, &slice); err == nil {
return handleListInput(slice, s)
}

return fmt.Errorf("unrecognized type for devices spec: %s", string(b))
}

func handleStringInput(str string, s *ReplicatedDevices) error {
if str != "all" {
return fmt.Errorf("devices set as '%v' but the only valid string input is 'all'", str)
}
s.All = true
return nil
}

func handleIntInput(count int, s *ReplicatedDevices) error {
if count <= 0 {
return fmt.Errorf("devices set as '%v' but a count of devices must be > 0", count)
}
s.Count = count
return nil
}

func handleListInput(slice []json.RawMessage, s *ReplicatedDevices) error {
result := make([]ReplicatedDeviceRef, len(slice))
for i, raw := range slice {
var index uint
if err := json.Unmarshal(raw, &index); err == nil {
result[i] = ReplicatedDeviceRef(strconv.Itoa(int(index)))
continue
}

var item string
if err := json.Unmarshal(raw, &item); err == nil {
rd := ReplicatedDeviceRef(item)
if rd.IsGPUIndex() || rd.IsMigIndex() || rd.IsUUID() {
result[i] = rd
continue
}
// Match strings as valid entries if they are GPU indices, MIG indices, or UUIDs
var item string
if err = json.Unmarshal(s, &item); err == nil {
rd := ReplicatedDeviceRef(item)
if rd.IsGPUIndex() || rd.IsMigIndex() || rd.IsUUID() {
result[i] = rd
continue
}
}
// Treat any other entries as errors
return fmt.Errorf("unsupported type for device in devices list: %v, %T", item, item)
}
s.List = result
return nil
return fmt.Errorf("unsupported type for device in devices list: %v, %T", item, item)
}

// No matches found
return fmt.Errorf("unrecognized type for devices spec: %v", string(b))
s.List = result
return nil
}

// MarshalJSON marshals ReplicatedDevices to its raw bytes representation
Expand Down
28 changes: 28 additions & 0 deletions api/config/v1/replicas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package v1

import (
"fmt"
"reflect"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -464,3 +465,30 @@ func TestUnmarshalReplicatedResources(t *testing.T) {
})
}
}
func TestUnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
want ReplicatedDevices
wantErr bool
}{
{"All devices", `"all"`, ReplicatedDevices{All: true}, false},
{"Count of devices", `3`, ReplicatedDevices{Count: 3}, false},
{"Devices list", `["0:1"]`, ReplicatedDevices{List: []ReplicatedDeviceRef{"0:1"}}, false},
{"Invalid datatype", `{"key":"value"}`, ReplicatedDevices{}, true},
{"Invalid device count", `0`, ReplicatedDevices{}, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got ReplicatedDevices
if err := got.UnmarshalJSON([]byte(tt.input)); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want)
}
})
}
}