Skip to content

Commit

Permalink
sql/parser: inject parameter types into aggregate constructors
Browse files Browse the repository at this point in the history
Fixes cockroachdb#12207.

This change adjusts aggregate constructors take their parameter types.
This is useful for when aggregate behavior needs to change depending on
the input parameters. For instance, the same constructor needs to be
used for `array_agg(string)` and `array_agg(name)`, but their runtime
behavior needs to be different.
  • Loading branch information
nvanbenschoten committed Jan 28, 2017
1 parent a252c7f commit eb43399
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 69 deletions.
5 changes: 4 additions & 1 deletion pkg/sql/distsqlrun/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ func GetAggregateInfo(
for _, t := range b.Types.Types() {
if inputDatumType.Equivalent(t) {
// Found!
return b.AggregateFunc, sqlbase.DatumTypeToColumnType(b.FixedReturnType()), nil
constructAgg := func() parser.AggregateFunc {
return b.AggregateFunc([]parser.Type{inputDatumType})
}
return constructAgg, sqlbase.DatumTypeToColumnType(b.FixedReturnType()), nil
}
}
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,7 @@ func checkResultType(typ parser.Type) error {
case parser.TypeTimestampTZ:
case parser.TypeInterval:
case parser.TypeStringArray:
case parser.TypeNameArray:
case parser.TypeIntArray:
default:
// Compare all types that cannot rely on == equality.
Expand Down
77 changes: 36 additions & 41 deletions pkg/sql/parser/aggregate_builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ type AggregateFunc interface {
// Exported for use in documentation.
var Aggregates = map[string][]Builtin{
"array_agg": {
makeAggBuiltin(TypeInt, TypeIntArray, newIntArrayAggregate,
"Aggregates the selected values into an array."),
makeAggBuiltin(
TypeString, TypeStringArray, newStringArrayAggregate,
makeIdentityArrayAggBuiltin(newArrayAggregate,
"Aggregates the selected values into an array."),
},

Expand Down Expand Up @@ -166,7 +163,7 @@ var Aggregates = map[string][]Builtin{
},
}

func makeAggBuiltin(in, ret Type, f func() AggregateFunc, info string) Builtin {
func makeAggBuiltin(in, ret Type, f func([]Type) AggregateFunc, info string) Builtin {
return Builtin{
// See the comment about aggregate functions in the definitions
// of the Builtins array above.
Expand All @@ -175,14 +172,14 @@ func makeAggBuiltin(in, ret Type, f func() AggregateFunc, info string) Builtin {
Types: ArgTypes{{"arg", in}},
ReturnType: fixedReturnType{ret},
AggregateFunc: f,
WindowFunc: func() WindowFunc {
return newAggregateWindow(f())
WindowFunc: func(params []Type) WindowFunc {
return newAggregateWindow(f(params))
},
Info: info,
}
}

func makeIdentityArrayAggBuiltin(f func() AggregateFunc, info string) Builtin {
func makeIdentityArrayAggBuiltin(f func([]Type) AggregateFunc, info string) Builtin {
b := makeAggBuiltin(TypeAny, nil, f, info)
b.ReturnType = identityArrayReturnType{0}
return b
Expand Down Expand Up @@ -240,12 +237,10 @@ type arrayAggregate struct {
arr *DArray
}

func newIntArrayAggregate() AggregateFunc {
return &arrayAggregate{arr: NewDArray(TypeInt)}
}

func newStringArrayAggregate() AggregateFunc {
return &arrayAggregate{arr: NewDArray(TypeString)}
func newArrayAggregate(params []Type) AggregateFunc {
return &arrayAggregate{
arr: NewDArray(params[0]),
}
}

// Add accumulates the passed datum into the array.
Expand All @@ -268,14 +263,14 @@ type avgAggregate struct {
count int
}

func newIntAvgAggregate() AggregateFunc {
return &avgAggregate{agg: newIntSumAggregate()}
func newIntAvgAggregate(params []Type) AggregateFunc {
return &avgAggregate{agg: newIntSumAggregate(params)}
}
func newFloatAvgAggregate() AggregateFunc {
return &avgAggregate{agg: newFloatSumAggregate()}
func newFloatAvgAggregate(params []Type) AggregateFunc {
return &avgAggregate{agg: newFloatSumAggregate(params)}
}
func newDecimalAvgAggregate() AggregateFunc {
return &avgAggregate{agg: newDecimalSumAggregate()}
func newDecimalAvgAggregate(params []Type) AggregateFunc {
return &avgAggregate{agg: newDecimalSumAggregate(params)}
}

// Add accumulates the passed datum into the average.
Expand Down Expand Up @@ -311,10 +306,10 @@ type concatAggregate struct {
result bytes.Buffer
}

func newBytesConcatAggregate() AggregateFunc {
func newBytesConcatAggregate(_ []Type) AggregateFunc {
return &concatAggregate{forBytes: true}
}
func newStringConcatAggregate() AggregateFunc {
func newStringConcatAggregate(_ []Type) AggregateFunc {
return &concatAggregate{}
}

Expand Down Expand Up @@ -349,7 +344,7 @@ type boolAndAggregate struct {
result bool
}

func newBoolAndAggregate() AggregateFunc {
func newBoolAndAggregate(_ []Type) AggregateFunc {
return &boolAndAggregate{}
}

Expand All @@ -376,7 +371,7 @@ type boolOrAggregate struct {
result bool
}

func newBoolOrAggregate() AggregateFunc {
func newBoolOrAggregate(_ []Type) AggregateFunc {
return &boolOrAggregate{}
}

Expand All @@ -399,7 +394,7 @@ type countAggregate struct {
count int
}

func newCountAggregate() AggregateFunc {
func newCountAggregate(_ []Type) AggregateFunc {
return &countAggregate{}
}

Expand All @@ -420,7 +415,7 @@ type MaxAggregate struct {
max Datum
}

func newMaxAggregate() AggregateFunc {
func newMaxAggregate(_ []Type) AggregateFunc {
return &MaxAggregate{}
}

Expand Down Expand Up @@ -452,7 +447,7 @@ type MinAggregate struct {
min Datum
}

func newMinAggregate() AggregateFunc {
func newMinAggregate(_ []Type) AggregateFunc {
return &MinAggregate{}
}

Expand Down Expand Up @@ -484,7 +479,7 @@ type smallIntSumAggregate struct {
seenNonNull bool
}

func newSmallIntSumAggregate() AggregateFunc {
func newSmallIntSumAggregate(_ []Type) AggregateFunc {
return &smallIntSumAggregate{}
}

Expand Down Expand Up @@ -517,7 +512,7 @@ type intSumAggregate struct {
seenNonNull bool
}

func newIntSumAggregate() AggregateFunc {
func newIntSumAggregate(_ []Type) AggregateFunc {
return &intSumAggregate{}
}

Expand Down Expand Up @@ -571,7 +566,7 @@ type decimalSumAggregate struct {
sawNonNull bool
}

func newDecimalSumAggregate() AggregateFunc {
func newDecimalSumAggregate(_ []Type) AggregateFunc {
return &decimalSumAggregate{}
}

Expand Down Expand Up @@ -600,7 +595,7 @@ type floatSumAggregate struct {
sawNonNull bool
}

func newFloatSumAggregate() AggregateFunc {
func newFloatSumAggregate(_ []Type) AggregateFunc {
return &floatSumAggregate{}
}

Expand All @@ -627,7 +622,7 @@ type intervalSumAggregate struct {
sawNonNull bool
}

func newIntervalSumAggregate() AggregateFunc {
func newIntervalSumAggregate(_ []Type) AggregateFunc {
return &intervalSumAggregate{}
}

Expand Down Expand Up @@ -655,7 +650,7 @@ type intVarianceAggregate struct {
tmpDec DDecimal
}

func newIntVarianceAggregate() AggregateFunc {
func newIntVarianceAggregate(_ []Type) AggregateFunc {
return &intVarianceAggregate{}
}

Expand All @@ -678,7 +673,7 @@ type floatVarianceAggregate struct {
sqrDiff float64
}

func newFloatVarianceAggregate() AggregateFunc {
func newFloatVarianceAggregate(_ []Type) AggregateFunc {
return &floatVarianceAggregate{}
}

Expand Down Expand Up @@ -715,7 +710,7 @@ type decimalVarianceAggregate struct {
tmp inf.Dec
}

func newDecimalVarianceAggregate() AggregateFunc {
func newDecimalVarianceAggregate(_ []Type) AggregateFunc {
return &decimalVarianceAggregate{}
}

Expand Down Expand Up @@ -756,14 +751,14 @@ type stdDevAggregate struct {
agg AggregateFunc
}

func newIntStdDevAggregate() AggregateFunc {
return &stdDevAggregate{agg: newIntVarianceAggregate()}
func newIntStdDevAggregate(params []Type) AggregateFunc {
return &stdDevAggregate{agg: newIntVarianceAggregate(params)}
}
func newFloatStdDevAggregate() AggregateFunc {
return &stdDevAggregate{agg: newFloatVarianceAggregate()}
func newFloatStdDevAggregate(params []Type) AggregateFunc {
return &stdDevAggregate{agg: newFloatVarianceAggregate(params)}
}
func newDecimalStdDevAggregate() AggregateFunc {
return &stdDevAggregate{agg: newDecimalVarianceAggregate()}
func newDecimalStdDevAggregate(params []Type) AggregateFunc {
return &stdDevAggregate{agg: newDecimalVarianceAggregate(params)}
}

// Add implements the AggregateFunc interface.
Expand Down
9 changes: 5 additions & 4 deletions pkg/sql/parser/aggregate_builtins_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (
// printing all values to strings once the accumulation has finished. If the string
// slices are not equal, it means that the result Datums were modified during later
// accumulation, which violates the "deep copy of any internal state" condition.
func testAggregateResultDeepCopy(t *testing.T, aggFunc func() AggregateFunc, vals []Datum) {
aggImpl := aggFunc()
func testAggregateResultDeepCopy(t *testing.T, aggFunc func([]Type) AggregateFunc, vals []Datum) {
aggImpl := aggFunc([]Type{vals[0].ResolvedType()})
runningDatums := make([]Datum, len(vals))
runningStrings := make([]string, len(vals))
for i := range vals {
Expand Down Expand Up @@ -223,10 +223,11 @@ func makeIntervalTestDatum(count int) []Datum {
return vals
}

func runBenchmarkAggregate(b *testing.B, aggFunc func() AggregateFunc, vals []Datum) {
func runBenchmarkAggregate(b *testing.B, aggFunc func([]Type) AggregateFunc, vals []Datum) {
params := []Type{vals[0].ResolvedType()}
b.ResetTimer()
for i := 0; i < b.N; i++ {
aggImpl := aggFunc()
aggImpl := aggFunc(params)
for i := range vals {
aggImpl.Add(vals[i])
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/parser/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ type Builtin struct {
// might be more appropriate.
Info string

AggregateFunc func() AggregateFunc
WindowFunc func() WindowFunc
AggregateFunc func([]Type) AggregateFunc
WindowFunc func([]Type) WindowFunc
fn func(*EvalContext, DTuple) (Datum, error)
}

Expand Down
24 changes: 22 additions & 2 deletions pkg/sql/parser/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -890,13 +890,33 @@ type FuncExpr struct {
// GetAggregateConstructor exposes the AggregateFunc field for use by
// the group node in package sql.
func (node *FuncExpr) GetAggregateConstructor() func() AggregateFunc {
return node.fn.AggregateFunc
if node.fn.AggregateFunc == nil {
return nil
}
return func() AggregateFunc {
types := typesOfExprs(node.Exprs)
return node.fn.AggregateFunc(types)
}
}

// GetWindowConstructor returns a window function constructor if the
// FuncExpr is a built-in window function.
func (node *FuncExpr) GetWindowConstructor() func() WindowFunc {
return node.fn.WindowFunc
if node.fn.WindowFunc == nil {
return nil
}
return func() WindowFunc {
types := typesOfExprs(node.Exprs)
return node.fn.WindowFunc(types)
}
}

func typesOfExprs(exprs Exprs) []Type {
types := make([]Type, len(exprs))
for i, expr := range exprs {
types[i] = expr.(TypedExpr).ResolvedType()
}
return types
}

// IsWindowFunctionApplication returns if the function is being applied as a window function.
Expand Down
20 changes: 13 additions & 7 deletions pkg/sql/parser/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ var (
// TypeName is a type-alias for TypeString with a different OID. Can be
// compared with ==.
TypeName = wrapTypeWithOid(TypeString, oid.T_name)
// TypeNameArray is the type family of a DArray containing the Name alias type.
// Can be compared with ==.
TypeNameArray Type = tArray{TypeName}

// TypesAnyNonArray contains all non-array types.
TypesAnyNonArray = []Type{
Expand Down Expand Up @@ -453,16 +456,19 @@ func (tArray) Size() (uintptr, bool) {
return unsafe.Sizeof(DString("")), variableSize
}

// oidToArrayOid maps scalar type Oids to their corresponding array type Oid.
var oidToArrayOid = map[oid.Oid]oid.Oid{
oid.T_int8: oid.T__int8,
oid.T_text: oid.T__text,
oid.T_name: oid.T__name,
}

// Oid implements the Type interface.
func (a tArray) Oid() oid.Oid {
switch a.Typ {
case TypeInt:
return oid.T__int8
case TypeString:
return oid.T__text
default:
return oid.T_anyarray
if o, ok := oidToArrayOid[a.Typ.Oid()]; ok {
return o
}
return oid.T_anyarray
}

// SQLName implements the Type interface.
Expand Down
Loading

0 comments on commit eb43399

Please sign in to comment.