diff --git a/cel/env.go b/cel/env.go index 5cbb86a7..e8d6e2ea 100644 --- a/cel/env.go +++ b/cel/env.go @@ -47,16 +47,25 @@ type Ast struct { // Expr returns the proto serializable instance of the parsed/checked expression. func (ast *Ast) Expr() *exprpb.Expr { + if ast == nil { + return nil + } return ast.expr } // IsChecked returns whether the Ast value has been successfully type-checked. func (ast *Ast) IsChecked() bool { + if ast == nil { + return false + } return ast.typeMap != nil && len(ast.typeMap) > 0 } // SourceInfo returns character offset and newline position information about expression elements. func (ast *Ast) SourceInfo() *exprpb.SourceInfo { + if ast == nil { + return nil + } return ast.info } @@ -65,9 +74,6 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo { // // Deprecated: use OutputType func (ast *Ast) ResultType() *exprpb.Type { - if !ast.IsChecked() { - return chkdecls.Dyn - } out := ast.OutputType() t, err := TypeToExprType(out) if err != nil { @@ -79,6 +85,9 @@ func (ast *Ast) ResultType() *exprpb.Type { // OutputType returns the output type of the expression if the Ast has been type-checked, else // returns cel.DynType as the parse step cannot infer types. func (ast *Ast) OutputType() *Type { + if ast == nil { + return types.ErrorType + } t, found := ast.typeMap[ast.expr.GetId()] if !found { return DynType @@ -89,6 +98,9 @@ func (ast *Ast) OutputType() *Type { // Source returns a view of the input used to create the Ast. This source may be complete or // constructed from the SourceInfo. func (ast *Ast) Source() Source { + if ast == nil { + return nil + } return ast.source } diff --git a/cel/env_test.go b/cel/env_test.go index c8f7ebee..66372796 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -29,6 +29,25 @@ import ( exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) +func TestAstNil(t *testing.T) { + var ast *Ast + if ast.IsChecked() { + t.Error("ast.IsChecked() returned true for nil ast") + } + if ast.Expr() != nil { + t.Errorf("ast.Expr() got %v, wanted nil", ast.Expr()) + } + if ast.SourceInfo() != nil { + t.Errorf("ast.SourceInfo() got %v, wanted nil", ast.SourceInfo()) + } + if ast.OutputType() != types.ErrorType { + t.Errorf("ast.OutputType() got %v, wanted error type", ast.OutputType()) + } + if ast.Source() != nil { + t.Errorf("ast.Source() got %v, wanted nil", ast.Source()) + } +} + func TestIssuesNil(t *testing.T) { var iss *Issues iss = iss.Append(iss)