From 933f926a7fbc21d664a2894388e7a7811ae2ffcb Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 22 Nov 2024 08:55:42 -0800 Subject: [PATCH] Fix nil-type when two-var comprehension has a dyn range (#1077) --- checker/checker.go | 6 ++++++ ext/comprehensions_test.go | 7 +++++++ policy/compiler_test.go | 4 ++++ 3 files changed, 17 insertions(+) diff --git a/checker/checker.go b/checker/checker.go index 0603cfa3..6824af7a 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -529,9 +529,15 @@ func (c *checker) checkComprehension(e ast.Expr) { c.isAssignable(types.DynType, rangeType) // Set the range iteration variable to type DYN as well. varType = types.DynType + if comp.HasIterVar2() { + var2Type = types.DynType + } default: c.errors.notAComprehensionRange(comp.IterRange().ID(), c.location(comp.IterRange()), rangeType) varType = types.ErrorType + if comp.HasIterVar2() { + var2Type = types.ErrorType + } } // Create a block scope for the loop. diff --git a/ext/comprehensions_test.go b/ext/comprehensions_test.go index 84d82c37..6416e0e7 100644 --- a/ext/comprehensions_test.go +++ b/ext/comprehensions_test.go @@ -139,6 +139,9 @@ func TestTwoVarComprehensions(t *testing.T) { {'Hello': 'world'}.transformList(k, v, "%s=%s".format([k.lowerAscii(), v])) == ["hello=world"] `}, {expr: ` + dyn({'Hello': 'world'}).transformList(k, v, "%s=%s".format([k.lowerAscii(), v])) == ["hello=world"] + `}, + {expr: ` {'hello': 'world'}.transformList(k, v, k.startsWith('greeting'), "%s=%s".format([k, v])) == [] `}, {expr: ` @@ -155,6 +158,10 @@ func TestTwoVarComprehensions(t *testing.T) { == {'hello': 'hello, world!', 'goodbye': 'goodbye, cruel world!'} `}, {expr: ` + dyn({'hello': 'world', 'goodbye': 'cruel world'}).transformMap(k, v, "%s, %s!".format([k, v])) + == {'hello': 'hello, world!', 'goodbye': 'goodbye, cruel world!'} + `}, + {expr: ` {'hello': 'world', 'goodbye': 'cruel world'}.transformMap(k, v, v.startsWith('world'), "%s, %s!".format([k, v])) == {'hello': 'hello, world!'} `}, diff --git a/policy/compiler_test.go b/policy/compiler_test.go index 5865f524..7aa76bde 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -192,6 +192,10 @@ func (r *runner) setup(t testing.TB) { if err != nil { t.Fatalf("cel.AstToString() failed: %v", err) } + _, err = cel.AstToCheckedExpr(ast) + if err != nil { + t.Fatalf("cel.AstToCheckedExpr() failed: %v", err) + } if r.expr != "" && normalize(pExpr) != normalize(r.expr) { t.Errorf("cel.AstToString() got %s, wanted %s", pExpr, r.expr) }