Skip to content

Commit

Permalink
cty: Refine the ranges of arithmetic results
Browse files Browse the repository at this point in the history
If we are performing addition, subtraction, or multiplication on unknown
numbers with known numeric bounds then we can propagate bounds to the
result by performing interval arithmetic.

This is not as complete as it could be because of trying to share a single
implementation across all of the functions while still dealing with all
of their panic edge cases.
  • Loading branch information
apparentlymart committed Feb 4, 2023
1 parent 448ca74 commit 7416265
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 14 deletions.
4 changes: 2 additions & 2 deletions cty/msgpack/unknown.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func marshalUnknownValue(rng cty.ValueRange, path cty.Path, enc *msgpack.Encoder
lower, lowerInc := rng.NumberLowerBound()
upper, upperInc := rng.NumberUpperBound()
boundTy := cty.Tuple([]cty.Type{cty.Number, cty.Bool})
if lower.IsKnown() {
if lower.IsKnown() && lower != cty.NegativeInfinity {
mapLen++
refnEnc.EncodeInt(int64(unknownValNumberMin))
marshal(
Expand All @@ -73,7 +73,7 @@ func marshalUnknownValue(rng cty.ValueRange, path cty.Path, enc *msgpack.Encoder
refnEnc,
)
}
if upper.IsKnown() {
if upper.IsKnown() && upper != cty.PositiveInfinity {
mapLen++
refnEnc.EncodeInt(int64(unknownValNumberMax))
marshal(
Expand Down
4 changes: 2 additions & 2 deletions cty/unknown_refinement.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,10 +632,10 @@ func (r *refinementNumber) rawEqual(other unknownValRefinement) bool {
func (r *refinementNumber) GoString() string {
var b strings.Builder
b.WriteString(r.refinementNullable.GoString())
if r.min != NilVal {
if r.min != NilVal && r.min != NegativeInfinity {
fmt.Fprintf(&b, ".NumberLowerBound(%#v, %t)", r.min, r.minInc)
}
if r.max != NilVal {
if r.max != NilVal && r.max != PositiveInfinity {
fmt.Fprintf(&b, ".NumberUpperBound(%#v, %t)", r.max, r.maxInc)
}
return b.String()
Expand Down
12 changes: 9 additions & 3 deletions cty/value_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,8 @@ func (val Value) Add(other Value) Value {

if shortCircuit := mustTypeCheck(Number, Number, val, other); shortCircuit != nil {
shortCircuit = forceShortCircuitType(shortCircuit, Number)
return (*shortCircuit).RefineNotNull()
ret := shortCircuit.RefineWith(numericRangeArithmetic(Value.Add, val.Range(), other.Range()))
return ret.RefineNotNull()
}

ret := new(big.Float)
Expand All @@ -612,7 +613,8 @@ func (val Value) Subtract(other Value) Value {

if shortCircuit := mustTypeCheck(Number, Number, val, other); shortCircuit != nil {
shortCircuit = forceShortCircuitType(shortCircuit, Number)
return (*shortCircuit).RefineNotNull()
ret := shortCircuit.RefineWith(numericRangeArithmetic(Value.Subtract, val.Range(), other.Range()))
return ret.RefineNotNull()
}

return val.Add(other.Negate())
Expand Down Expand Up @@ -646,7 +648,8 @@ func (val Value) Multiply(other Value) Value {

if shortCircuit := mustTypeCheck(Number, Number, val, other); shortCircuit != nil {
shortCircuit = forceShortCircuitType(shortCircuit, Number)
return (*shortCircuit).RefineNotNull()
ret := shortCircuit.RefineWith(numericRangeArithmetic(Value.Multiply, val.Range(), other.Range()))
return ret.RefineNotNull()
}

// find the larger precision of the arguments
Expand Down Expand Up @@ -691,6 +694,9 @@ func (val Value) Divide(other Value) Value {

if shortCircuit := mustTypeCheck(Number, Number, val, other); shortCircuit != nil {
shortCircuit = forceShortCircuitType(shortCircuit, Number)
// TODO: We could potentially refine the range of the result here, but
// we don't right now because our division operation is not monotone
// if the denominator could potentially be zero.
return (*shortCircuit).RefineNotNull()
}

Expand Down
192 changes: 187 additions & 5 deletions cty/value_ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1767,6 +1767,66 @@ func TestValueAdd(t *testing.T) {
UnknownVal(Number),
UnknownVal(Number).RefineNotNull(),
},
{
NumberIntVal(1),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(3), true).
NewValue(),
},
{
Zero,
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(2), true).
NewValue(),
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(4), true).
NewValue(),
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(1), true).
NumberRangeUpperBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(3), true).
NewValue(),
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(1), true).
NumberRangeUpperBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NumberRangeUpperBound(NumberIntVal(3), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(3), true).
NumberRangeUpperBound(NumberIntVal(5), true).
NewValue(),
},
{
UnknownVal(Number),
UnknownVal(Number),
Expand Down Expand Up @@ -1803,7 +1863,7 @@ func TestValueAdd(t *testing.T) {
t.Run(fmt.Sprintf("%#v.Add(%#v)", test.LHS, test.RHS), func(t *testing.T) {
got := test.LHS.Add(test.RHS)
if !got.RawEquals(test.Expected) {
t.Fatalf("Add returned %#v; want %#v", got, test.Expected)
t.Fatalf("Wrong result\ngot: %#v\nwant: %#v", got, test.Expected)
}
})
}
Expand Down Expand Up @@ -1840,6 +1900,63 @@ func TestValueSubtract(t *testing.T) {
UnknownVal(Number),
UnknownVal(Number).RefineNotNull(),
},
{
NumberIntVal(1),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), true).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeUpperBound(NumberIntVal(-1), true).
NewValue(),
},
{
Zero,
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), true).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeUpperBound(NumberIntVal(-2), true).
NewValue(),
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), true).
NewValue(),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), true).
NewValue(),
UnknownVal(Number).RefineNotNull(), // We don't currently refine this case
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(1), true).
NumberRangeUpperBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), true).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeUpperBound(NumberIntVal(0), true).
NewValue(),
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(1), true).
NumberRangeUpperBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NumberRangeUpperBound(NumberIntVal(3), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(-2), true).
NumberRangeUpperBound(NumberIntVal(0), true).
NewValue(),
},
{
NumberIntVal(1),
DynamicVal,
Expand Down Expand Up @@ -1871,7 +1988,7 @@ func TestValueSubtract(t *testing.T) {
t.Run(fmt.Sprintf("%#v.Subtract(%#v)", test.LHS, test.RHS), func(t *testing.T) {
got := test.LHS.Subtract(test.RHS)
if !got.RawEquals(test.Expected) {
t.Fatalf("Subtract returned %#v; want %#v", got, test.Expected)
t.Fatalf("wrong result\ngot: %#v\nwant: %#v", got, test.Expected)
}
})
}
Expand Down Expand Up @@ -1908,7 +2025,7 @@ func TestValueNegate(t *testing.T) {
t.Run(fmt.Sprintf("%#v.Negate()", test.Receiver), func(t *testing.T) {
got := test.Receiver.Negate()
if !got.RawEquals(test.Expected) {
t.Fatalf("Negate returned %#v; want %#v", got, test.Expected)
t.Fatalf("wrong result\ngot: %#v\nwant: %#v", got, test.Expected)
}
})
}
Expand Down Expand Up @@ -1945,6 +2062,71 @@ func TestValueMultiply(t *testing.T) {
UnknownVal(Number),
UnknownVal(Number).RefineNotNull(),
},
{
NumberIntVal(3),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(6), true).
NewValue(),
},
{
Zero,
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).RefineNotNull(), // We can't currently refine this case
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(4), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(8), true).
NewValue(),
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(3), true).
NumberRangeUpperBound(NumberIntVal(4), false).
NewValue(),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(6), true).
NewValue(),
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(1), true).
NumberRangeUpperBound(NumberIntVal(2), false).
NewValue(),
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(2), false).
NumberRangeUpperBound(NumberIntVal(3), false).
NewValue(),
UnknownVal(Number).Refine().
NotNull().
NumberRangeLowerBound(NumberIntVal(2), true).
NumberRangeUpperBound(NumberIntVal(6), true).
NewValue(),
},
{
UnknownVal(Number).Refine().
NumberRangeLowerBound(NumberIntVal(1), true).
NumberRangeUpperBound(NumberIntVal(2), false).
NewValue(),
Zero,
Zero, // deduced by refinement
},
{
NumberIntVal(1),
DynamicVal,
Expand Down Expand Up @@ -1986,7 +2168,7 @@ func TestValueMultiply(t *testing.T) {
t.Run(fmt.Sprintf("%#v.Multiply(%#v)", test.LHS, test.RHS), func(t *testing.T) {
got := test.LHS.Multiply(test.RHS)
if !got.RawEquals(test.Expected) {
t.Fatalf("Multiply returned %#v; want %#v", got, test.Expected)
t.Fatalf("wrong result\ngot: %#v\nwant: %#v", got, test.Expected)
}
})
}
Expand Down Expand Up @@ -2064,7 +2246,7 @@ func TestValueDivide(t *testing.T) {
t.Run(fmt.Sprintf("%#v.Divide(%#v)", test.LHS, test.RHS), func(t *testing.T) {
got := test.LHS.Divide(test.RHS)
if !got.RawEquals(test.Expected) {
t.Fatalf("Divide returned %#v; want %#v", got, test.Expected)
t.Fatalf("wrong result\ngot: %#v\nwant: %#v", got, test.Expected)
}
})
}
Expand Down
Loading

0 comments on commit 7416265

Please sign in to comment.