diff --git a/pkg/util/tools.go b/pkg/util/tools.go index 8729585f86..497bd636df 100644 --- a/pkg/util/tools.go +++ b/pkg/util/tools.go @@ -18,7 +18,10 @@ limitations under the License. package util import ( + "fmt" "math" + "strconv" + "strings" "sync" "github.com/docker/distribution/reference" @@ -163,7 +166,7 @@ func CalculatePartitionReplicas(partition *intstrutil.IntOrString, replicasPoint } // 'roundUp=true' will ensure at least 1 old pod is reserved if partition > "0%" and replicas > 0. - pValue, err := intstrutil.GetScaledValueFromIntOrPercent(partition, replicas, true) + pValue, err := GetScaledValueFromIntOrPercent(partition, replicas, true) if err != nil { return pValue, err } @@ -189,3 +192,33 @@ func IsReferenceEqual(ref1, ref2 appsv1alpha1.TargetReference) bool { } return gv1.Group == gv2.Group && ref1.Kind == ref2.Kind && ref1.Name == ref2.Name } + +func GetScaledValueFromIntOrPercent(intOrPercent *intstrutil.IntOrString, total int, roundUp bool) (int, error) { + if intOrPercent == nil { + return 0, fmt.Errorf("nil value for IntOrString") + } + + switch intOrPercent.Type { + case intstrutil.Int: + return intOrPercent.IntValue(), nil + case intstrutil.String: + s := intOrPercent.StrVal + if strings.HasSuffix(s, "%") { + s = strings.TrimSuffix(intOrPercent.StrVal, "%") + } else { + return 0, fmt.Errorf("invalid type: string is not a percentage") + } + v, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0, err + } + var value int + if roundUp { + value = int(math.Ceil(v * (float64(total)) / 100)) + } else { + value = int(math.Floor(v * (float64(total)) / 100)) + } + return value, nil + } + return 0, fmt.Errorf("invalid type: neither int nor percentage") +} diff --git a/pkg/util/tools_test.go b/pkg/util/tools_test.go index 1cdf7a09fa..7d861d8b5e 100644 --- a/pkg/util/tools_test.go +++ b/pkg/util/tools_test.go @@ -299,3 +299,79 @@ func TestCalculatePartitionReplicas(t *testing.T) { }) } } + +func TestGetScaledValueFromIntOrPercent(t *testing.T) { + tests := []struct { + input intstr.IntOrString + total int + roundUp bool + expectErr bool + expectVal int + }{ + { + input: intstr.FromInt(123), + expectErr: false, + expectVal: 123, + }, + { + input: intstr.FromString("90%"), + total: 100, + roundUp: true, + expectErr: false, + expectVal: 90, + }, + { + input: intstr.FromString("90%"), + total: 95, + roundUp: true, + expectErr: false, + expectVal: 86, + }, + { + input: intstr.FromString("90%"), + total: 95, + roundUp: false, + expectErr: false, + expectVal: 85, + }, + { + input: intstr.FromString("99.99%"), + total: 95, + roundUp: false, + expectErr: false, + expectVal: 94, + }, + { + input: intstr.FromString("%"), + expectErr: true, + }, + { + input: intstr.FromString("90#"), + expectErr: true, + }, + { + input: intstr.FromString("#%"), + expectErr: true, + }, + { + input: intstr.FromString("90"), + expectErr: true, + }, + } + + for i, test := range tests { + t.Logf("test case %d", i) + value, err := GetScaledValueFromIntOrPercent(&test.input, test.total, test.roundUp) + if test.expectErr && err == nil { + t.Errorf("expected error, but got none") + continue + } + if !test.expectErr && err != nil { + t.Errorf("unexpected err: %v", err) + continue + } + if test.expectVal != value { + t.Errorf("expected %v, but got %v", test.expectVal, value) + } + } +} diff --git a/pkg/webhook/cloneset/validating/validation.go b/pkg/webhook/cloneset/validating/validation.go index 287d5d7a63..9a58078275 100644 --- a/pkg/webhook/cloneset/validating/validation.go +++ b/pkg/webhook/cloneset/validating/validation.go @@ -138,7 +138,7 @@ func (h *CloneSetCreateUpdateHandler) validateUpdateStrategy(strategy *appsv1alp appsv1alpha1.InPlaceOnlyCloneSetUpdateStrategyType))) } - partition, err := intstrutil.GetValueFromIntOrPercent(strategy.Partition, replicas, true) + partition, err := util.GetScaledValueFromIntOrPercent(strategy.Partition, replicas, true) if err != nil { allErrs = append(allErrs, field.Invalid(fldPath.Child("partition"), strategy.Partition.String(), fmt.Sprintf("failed getValueFromIntOrPercent for partition: %v", err)))