Skip to content

Commit

Permalink
function/stdlib: SortFunc always preserves the length of its input
Browse files Browse the repository at this point in the history
If the list argument is an unknown value or contains unknown values then
we can't possibly return a fully-known result, but we do at least know
that sorting will never change the number of elements and so we can refine
our unknown result using the range of the input value.

The refinements system automatically collapses an unknown list collection
whose upper and lower length bounds are equal into a known list where
all elements are unknown, so this automatically preserves the known-ness
of the input length in the case where we're given a known list with
unknown elements, without needing to handle that as a special case here.
  • Loading branch information
apparentlymart committed Feb 8, 2023
1 parent 199911c commit 6ff38c3
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 4 deletions.
18 changes: 14 additions & 4 deletions cty/function/stdlib/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,9 @@ var SortFunc = function.New(&function.Spec{
Description: "Applies a lexicographic sort to the elements of the given list.",
Params: []function.Parameter{
{
Name: "list",
Type: cty.List(cty.String),
Name: "list",
Type: cty.List(cty.String),
AllowUnknown: true,
},
},
Type: function.StaticReturnType(cty.List(cty.String)),
Expand All @@ -295,8 +296,17 @@ var SortFunc = function.New(&function.Spec{

if !listVal.IsWhollyKnown() {
// If some of the element values aren't known yet then we
// can't yet predict the order of the result.
return cty.UnknownVal(retType), nil
// can't yet predict the order of the result, but we can be
// sure that the length won't change.
ret := cty.UnknownVal(retType)
if listVal.Type().IsListType() {
rng := listVal.Range()
ret = ret.Refine().
CollectionLengthLowerBound(rng.LengthLowerBound()).
CollectionLengthUpperBound(rng.LengthUpperBound()).
NewValue()
}
return ret, nil
}
if listVal.LengthInt() == 0 { // Easy path
return listVal, nil
Expand Down
80 changes: 80 additions & 0 deletions cty/function/stdlib/string_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package stdlib

import (
"fmt"
"testing"

"github.com/zclconf/go-cty/cty"
Expand Down Expand Up @@ -477,3 +478,82 @@ func TestJoin(t *testing.T) {
})
}
}

func TestSort(t *testing.T) {
tests := []struct {
Input cty.Value
Want cty.Value
WantErr string
}{
{
cty.ListValEmpty(cty.String),
cty.ListValEmpty(cty.String),
``,
},
{
cty.ListVal([]cty.Value{cty.StringVal("a")}),
cty.ListVal([]cty.Value{cty.StringVal("a")}),
``,
},
{
cty.ListVal([]cty.Value{cty.StringVal("b"), cty.StringVal("a")}),
cty.ListVal([]cty.Value{cty.StringVal("a"), cty.StringVal("b")}),
``,
},
{
cty.ListVal([]cty.Value{cty.StringVal("b"), cty.StringVal("a"), cty.StringVal("c")}),
cty.ListVal([]cty.Value{cty.StringVal("a"), cty.StringVal("b"), cty.StringVal("c")}),
``,
},
{
cty.UnknownVal(cty.List(cty.String)),
cty.UnknownVal(cty.List(cty.String)).RefineNotNull(),
``,
},
{
// If the list contains any unknown values then we can still
// preserve the length of the list by generating a known list
// with unknown elements, because sort can never change the length.
cty.ListVal([]cty.Value{cty.StringVal("b"), cty.UnknownVal(cty.String)}),
cty.ListVal([]cty.Value{cty.UnknownVal(cty.String), cty.UnknownVal(cty.String)}),
``,
},
{
// For a completely unknown list we can still preserve any
// refinements it had for its length, because sorting can never
// change the length.
cty.UnknownVal(cty.List(cty.String)).Refine().
CollectionLengthLowerBound(1).
CollectionLengthUpperBound(2).
NewValue(),
cty.UnknownVal(cty.List(cty.String)).Refine().
NotNull().
CollectionLengthLowerBound(1).
CollectionLengthUpperBound(2).
NewValue(),
``,
},
}

for _, test := range tests {
t.Run(fmt.Sprintf("Sort(%#v)", test.Input), func(t *testing.T) {
got, err := Sort(test.Input)

if test.WantErr != "" {
errStr := fmt.Sprintf("%s", err)
if errStr != test.WantErr {
t.Errorf("wrong error\ngot: %s\nwant: %s", errStr, test.WantErr)
}
return
}

if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}

if !got.RawEquals(test.Want) {
t.Errorf("wrong result\ninput: %#v\ngot: %#v\nwant: %#v", test.Input, got, test.Want)
}
})
}
}

0 comments on commit 6ff38c3

Please sign in to comment.