diff --git a/cty/function/stdlib/string.go b/cty/function/stdlib/string.go index 04033026..57ebce1b 100644 --- a/cty/function/stdlib/string.go +++ b/cty/function/stdlib/string.go @@ -284,8 +284,9 @@ var SortFunc = function.New(&function.Spec{ Description: "Applies a lexicographic sort to the elements of the given list.", Params: []function.Parameter{ { - Name: "list", - Type: cty.List(cty.String), + Name: "list", + Type: cty.List(cty.String), + AllowUnknown: true, }, }, Type: function.StaticReturnType(cty.List(cty.String)), @@ -295,8 +296,17 @@ var SortFunc = function.New(&function.Spec{ if !listVal.IsWhollyKnown() { // If some of the element values aren't known yet then we - // can't yet predict the order of the result. - return cty.UnknownVal(retType), nil + // can't yet predict the order of the result, but we can be + // sure that the length won't change. + ret := cty.UnknownVal(retType) + if listVal.Type().IsListType() { + rng := listVal.Range() + ret = ret.Refine(). + CollectionLengthLowerBound(rng.LengthLowerBound()). + CollectionLengthUpperBound(rng.LengthUpperBound()). + NewValue() + } + return ret, nil } if listVal.LengthInt() == 0 { // Easy path return listVal, nil diff --git a/cty/function/stdlib/string_test.go b/cty/function/stdlib/string_test.go index 3f27addb..254a2e73 100644 --- a/cty/function/stdlib/string_test.go +++ b/cty/function/stdlib/string_test.go @@ -1,6 +1,7 @@ package stdlib import ( + "fmt" "testing" "github.com/zclconf/go-cty/cty" @@ -477,3 +478,82 @@ func TestJoin(t *testing.T) { }) } } + +func TestSort(t *testing.T) { + tests := []struct { + Input cty.Value + Want cty.Value + WantErr string + }{ + { + cty.ListValEmpty(cty.String), + cty.ListValEmpty(cty.String), + ``, + }, + { + cty.ListVal([]cty.Value{cty.StringVal("a")}), + cty.ListVal([]cty.Value{cty.StringVal("a")}), + ``, + }, + { + cty.ListVal([]cty.Value{cty.StringVal("b"), cty.StringVal("a")}), + cty.ListVal([]cty.Value{cty.StringVal("a"), cty.StringVal("b")}), + ``, + }, + { + cty.ListVal([]cty.Value{cty.StringVal("b"), cty.StringVal("a"), cty.StringVal("c")}), + cty.ListVal([]cty.Value{cty.StringVal("a"), cty.StringVal("b"), cty.StringVal("c")}), + ``, + }, + { + cty.UnknownVal(cty.List(cty.String)), + cty.UnknownVal(cty.List(cty.String)).RefineNotNull(), + ``, + }, + { + // If the list contains any unknown values then we can still + // preserve the length of the list by generating a known list + // with unknown elements, because sort can never change the length. + cty.ListVal([]cty.Value{cty.StringVal("b"), cty.UnknownVal(cty.String)}), + cty.ListVal([]cty.Value{cty.UnknownVal(cty.String), cty.UnknownVal(cty.String)}), + ``, + }, + { + // For a completely unknown list we can still preserve any + // refinements it had for its length, because sorting can never + // change the length. + cty.UnknownVal(cty.List(cty.String)).Refine(). + CollectionLengthLowerBound(1). + CollectionLengthUpperBound(2). + NewValue(), + cty.UnknownVal(cty.List(cty.String)).Refine(). + NotNull(). + CollectionLengthLowerBound(1). + CollectionLengthUpperBound(2). + NewValue(), + ``, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("Sort(%#v)", test.Input), func(t *testing.T) { + got, err := Sort(test.Input) + + if test.WantErr != "" { + errStr := fmt.Sprintf("%s", err) + if errStr != test.WantErr { + t.Errorf("wrong error\ngot: %s\nwant: %s", errStr, test.WantErr) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + if !got.RawEquals(test.Want) { + t.Errorf("wrong result\ninput: %#v\ngot: %#v\nwant: %#v", test.Input, got, test.Want) + } + }) + } +}