From 79641b8126ed3f65cf66c374c6ccbdca9f40ae3a Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Fri, 5 May 2023 13:30:30 -0700 Subject: [PATCH] Make checker state concurrency safe --- cel/env.go | 55 ++++++++++++++++++++++++++++++++----------------- cel/env_test.go | 27 ++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/cel/env.go b/cel/env.go index 8cf442ee..9dc0ff5b 100644 --- a/cel/env.go +++ b/cel/env.go @@ -109,10 +109,11 @@ type Env struct { prsrOpts []parser.Option // Internal checker representation - chk *checker.Env - chkErr error - chkOnce sync.Once - chkOpts []checker.Option + chkMutex sync.Mutex + chk *checker.Env + chkErr error + chkOnce sync.Once + chkOpts []checker.Option // Program options tied to the environment progOpts []ProgramOption @@ -178,14 +179,14 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { pe, _ := AstToParsedExpr(ast) // Construct the internal checker env, erroring if there is an issue adding the declarations. - err := e.initChecker() + chk, err := e.initChecker() if err != nil { errs := common.NewErrors(ast.Source()) - errs.ReportError(common.NoLocation, e.chkErr.Error()) + errs.ReportError(common.NoLocation, err.Error()) return nil, NewIssues(errs) } - res, errs := checker.Check(pe, ast.Source(), e.chk) + res, errs := checker.Check(pe, ast.Source(), chk) if len(errs.GetErrors()) > 0 { return nil, NewIssues(errs) } @@ -239,8 +240,9 @@ func (e *Env) CompileSource(src Source) (*Ast, *Issues) { // TypeProvider are immutable, or that their underlying implementations are based on the // ref.TypeRegistry which provides a Copy method which will be invoked by this method. func (e *Env) Extend(opts ...EnvOption) (*Env, error) { - if e.chkErr != nil { - return nil, e.chkErr + chk, chkErr := e.getCheckerOrError() + if chkErr != nil { + return nil, chkErr } prsrOptsCopy := make([]parser.Option, len(e.prsrOpts)) @@ -254,10 +256,10 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { // Copy the declarations if needed. decsCopy := []*exprpb.Decl{} - if e.chk != nil { + if chk != nil { // If the type-checker has already been instantiated, then the e.declarations have been // validated within the chk instance. - chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(e.chk)) + chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(chk)) } else { // If the type-checker has not been instantiated, ensure the unvalidated declarations are // provided to the extended Env instance. @@ -509,7 +511,7 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) { // Ensure that the checker init happens eagerly rather than lazily. if e.HasFeature(featureEagerlyValidateDeclarations) { - err := e.initChecker() + _, err := e.initChecker() if err != nil { return nil, err } @@ -518,7 +520,7 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) { return e, nil } -func (e *Env) initChecker() error { +func (e *Env) initChecker() (*checker.Env, error) { e.chkOnce.Do(func() { chkOpts := []checker.Option{} chkOpts = append(chkOpts, e.chkOpts...) @@ -530,32 +532,47 @@ func (e *Env) initChecker() error { ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...) if err != nil { - e.chkErr = err + e.setCheckerOrError(nil, err) return } // Add the statically configured declarations. err = ce.Add(e.declarations...) if err != nil { - e.chkErr = err + e.setCheckerOrError(nil, err) return } // Add the function declarations which are derived from the FunctionDecl instances. for _, fn := range e.functions { fnDecl, err := functionDeclToExprDecl(fn) if err != nil { - e.chkErr = err + e.setCheckerOrError(nil, err) return } err = ce.Add(fnDecl) if err != nil { - e.chkErr = err + e.setCheckerOrError(nil, err) return } } // Add function declarations here separately. - e.chk = ce + e.setCheckerOrError(ce, nil) }) - return e.chkErr + return e.getCheckerOrError() +} + +// setCheckerOrError sets the checker.Env or error state in a concurrency-safe manner +func (e *Env) setCheckerOrError(chk *checker.Env, chkErr error) { + e.chkMutex.Lock() + e.chk = chk + e.chkErr = chkErr + e.chkMutex.Unlock() +} + +// getCheckerOrError gets the checker.Env or error state in a concurrency-safe manner +func (e *Env) getCheckerOrError() (*checker.Env, error) { + e.chkMutex.Lock() + defer e.chkMutex.Unlock() + return e.chk, e.chkErr } // maybeApplyFeature determines whether the feature-guarded option is enabled, and if so applies diff --git a/cel/env_test.go b/cel/env_test.go index e708d4b2..5a454b93 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -15,7 +15,9 @@ package cel import ( + "fmt" "reflect" + "sync" "testing" "github.com/google/cel-go/common" @@ -82,6 +84,31 @@ ERROR: :1:2: Syntax error: mismatched input '' expecting {'[', '{', } } +func TestEnvCheckExtendRace(t *testing.T) { + t.Parallel() + for i := 0; i < 500; i++ { + wg := &sync.WaitGroup{} + wg.Add(2) + env, err := NewCustomEnv(StdLib()) + if err != nil { + t.Fatalf("NewCustomEnv() failed: %v", err) + } + t.Run(fmt.Sprintf("Compile[%d]", i), func(t *testing.T) { + go func() { + defer wg.Done() + _, _ = env.Compile(`1 + 1 * 20 < 400`) + }() + }) + t.Run(fmt.Sprintf("Extend[%d]", i), func(t *testing.T) { + go func() { + defer wg.Done() + _, _ = env.Extend(Variable("bar", BoolType)) + }() + }) + wg.Wait() + } +} + func BenchmarkNewCustomEnvLazy(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ {