Skip to content

Commit

Permalink
Fix cost estimates to propagate result sizes. (#821)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpbetz authored Aug 23, 2023
1 parent 647ee3b commit f118dce
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
26 changes: 23 additions & 3 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -535,14 +535,34 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args

if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum())}
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
switch overloadID {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
if overloadID == overloads.ExtFormatString {
case overloads.ExtFormatString:
if target != nil {
// ResultSize not calculated because we can't bound the max size.
return CallEstimate{CostEstimate: c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
case overloads.StringToBytes:
if len(args) == 1 {
sz := c.sizeEstimate(args[0])
// ResultSize max is when each char converts to 4 bytes.
return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min, Max: sz.Max * 4}}
}
case overloads.BytesToString:
if len(args) == 1 {
sz := c.sizeEstimate(args[0])
// ResultSize min is when 4 bytes convert to 1 char.
return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min / 4, Max: sz.Max}}
}
case overloads.ExtQuoteString:
if len(args) == 1 {
sz := c.sizeEstimate(args[0])
// ResultSize max is when each char is escaped. 2 quote chars always added.
return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min + 2, Max: sz.Max*2 + 2}}
}
case overloads.StartsWithString, overloads.EndsWithString:
if len(args) == 1 {
return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
Expand Down
17 changes: 16 additions & 1 deletion checker/cost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/google/cel-go/common/stdlib"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"

proto3pb "github.com/google/cel-go/test/proto3pb"
)

Expand Down Expand Up @@ -261,13 +260,29 @@ func TestCost(t *testing.T) {
expr: `string(input)`,
wanted: CostEstimate{Min: 1, Max: 51},
},
{
name: "bytes to string conversion equality",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)},
hints: map[string]int64{"input": 500},
// equality check ensures that the resultSize calculation is included in cost
expr: `string(input) == string(input)`,
wanted: CostEstimate{Min: 3, Max: 152},
},
{
name: "string to bytes conversion",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)},
hints: map[string]int64{"input": 500},
expr: `bytes(input)`,
wanted: CostEstimate{Min: 1, Max: 51},
},
{
name: "string to bytes conversion equality",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)},
hints: map[string]int64{"input": 500},
// equality check ensures that the resultSize calculation is included in cost
expr: `bytes(input) == bytes(input)`,
wanted: CostEstimate{Min: 3, Max: 302},
},
{
name: "int to string conversion",
expr: `string(1)`,
Expand Down

0 comments on commit f118dce

Please sign in to comment.