Skip to content

Commit

Permalink
🎨Optimize UnmarshalJSON
Browse files Browse the repository at this point in the history
  • Loading branch information
haitwang-cloud committed May 6, 2024
1 parent e14655e commit e832823
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 44 deletions.
95 changes: 51 additions & 44 deletions api/config/v1/replicas.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,59 +269,66 @@ func (s *ReplicatedResource) UnmarshalJSON(b []byte) error {

// UnmarshalJSON unmarshals raw bytes into a 'ReplicatedDevices' struct.
func (s *ReplicatedDevices) UnmarshalJSON(b []byte) error {
// Match the string 'all'
var str string
err := json.Unmarshal(b, &str)
if err == nil {
if str != "all" {
return fmt.Errorf("devices set as '%v' but the only valid string input is 'all'", str)
}
s.All = true
return nil
var target interface{}
if err := json.Unmarshal(b, &target); err != nil {
return fmt.Errorf("unrecognized type for devices spec: %w", err)
}

// Match a count
var count int
err = json.Unmarshal(b, &count)
if err == nil {
if count <= 0 {
return fmt.Errorf("devices set as '%v' but a count of devices must be > 0", count)
switch t := target.(type) {
case string:
if err := handleStringInput(t, s); err != nil {
return err
}
s.Count = count
return nil
case float64:
if err := handleFloatInput(int(t), s); err != nil {
return err
}
case []interface{}:
if err := handleListInput(t, s); err != nil {
return err
}
default:
return fmt.Errorf("unsupported type for devices spec: %T", target)
}

// Match a list
var slice []json.RawMessage
err = json.Unmarshal(b, &slice)
if err == nil {
// For each item in the list check its format and convert it to a string (if necessary)
result := make([]ReplicatedDeviceRef, len(slice))
for i, s := range slice {
// Match a uint as a GPU index and convert it to a string
var index uint
if err = json.Unmarshal(s, &index); err == nil {
result[i] = ReplicatedDeviceRef(strconv.Itoa(int(index)))
continue
}
// Match strings as valid entries if they are GPU indices, MIG indices, or UUIDs
var item string
if err = json.Unmarshal(s, &item); err == nil {
rd := ReplicatedDeviceRef(item)
if rd.IsGPUIndex() || rd.IsMigIndex() || rd.IsUUID() {
result[i] = rd
continue
}
return nil
}

func handleStringInput(str string, s *ReplicatedDevices) error {
if str != "all" {
return fmt.Errorf("devices set as '%v' but the only valid string input is 'all'", str)
}
s.All = true
return nil
}

func handleFloatInput(count int, s *ReplicatedDevices) error {
if count <= 0 {
return fmt.Errorf("devices set as '%v' but a count of devices must be > 0", count)
}
s.Count = count
return nil
}

func handleListInput(items []interface{}, s *ReplicatedDevices) error {
result := make([]ReplicatedDeviceRef, len(items))
for i, item := range items {
switch v := item.(type) {
case float64:
result[i] = ReplicatedDeviceRef(strconv.Itoa(int(v)))
case string:
rd := ReplicatedDeviceRef(v)
if rd.IsGPUIndex() || rd.IsMigIndex() || rd.IsUUID() {
result[i] = rd
} else {
return fmt.Errorf("unsupported type for device in devices list: %v", v)
}
// Treat any other entries as errors
default:
return fmt.Errorf("unsupported type for device in devices list: %v, %T", item, item)
}
s.List = result
return nil
}

// No matches found
return fmt.Errorf("unrecognized type for devices spec: %v", string(b))
s.List = result
return nil
}

// MarshalJSON marshals ReplicatedDevices to its raw bytes representation
Expand Down
28 changes: 28 additions & 0 deletions api/config/v1/replicas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package v1

import (
"fmt"
"reflect"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -464,3 +465,30 @@ func TestUnmarshalReplicatedResources(t *testing.T) {
})
}
}
func TestUnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
want ReplicatedDevices
wantErr bool
}{
{"All devices", `"all"`, ReplicatedDevices{All: true}, false},
{"Count of devices", `3`, ReplicatedDevices{Count: 3}, false},
{"Devices list", `["0:1"]`, ReplicatedDevices{List: []ReplicatedDeviceRef{"0:1"}}, false},
{"Invalid datatype", `{"key":"value"}`, ReplicatedDevices{}, true},
{"Invalid device count", `0`, ReplicatedDevices{}, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got ReplicatedDevices
if err := got.UnmarshalJSON([]byte(tt.input)); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit e832823

Please sign in to comment.