Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql/parser: Define overload returnType interface and cleanup builtins #13209

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pkg/sql/distsqlrun/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ func GetAggregateInfo(
for _, t := range b.Types.Types() {
if inputDatumType.Equivalent(t) {
// Found!
return b.AggregateFunc, sqlbase.DatumTypeToColumnType(b.ReturnType), nil
constructAgg := func() parser.AggregateFunc {
return b.AggregateFunc([]parser.Type{inputDatumType})
}
return constructAgg, sqlbase.DatumTypeToColumnType(b.FixedReturnType()), nil
}
}
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1482,6 +1482,7 @@ func checkResultType(typ parser.Type) error {
case parser.TypeTimestampTZ:
case parser.TypeInterval:
case parser.TypeStringArray:
case parser.TypeNameArray:
case parser.TypeIntArray:
default:
// Compare all types that cannot rely on == equality.
Expand Down
132 changes: 64 additions & 68 deletions pkg/sql/parser/aggregate_builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,17 @@ type AggregateFunc interface {
// Exported for use in documentation.
var Aggregates = map[string][]Builtin{
"array_agg": {
makeAggBuiltin(TypeInt, TypeIntArray, newIntArrayAggregate,
"Aggregates the selected values into an array."),
makeAggBuiltin(
TypeString, TypeStringArray, newStringArrayAggregate,
"Aggregates the selected values into an array."),
makeAggBuiltinWithReturnType(
TypeAny,
func(args []TypedExpr) Type {
if len(args) == 0 {
return unknownReturnType
}
return tArray{args[0].ResolvedType()}
},
newArrayAggregate,
"Aggregates the selected values into an array.",
),
},

"avg": {
Expand Down Expand Up @@ -117,14 +123,19 @@ var Aggregates = map[string][]Builtin{
"Concatenates all selected values."),
},

"count": countImpls(),
"count": {
makeAggBuiltin(TypeAny, TypeInt, newCountAggregate,
"Calculates the number of selected elements."),
},

"max": makeAggBuiltins(newMaxAggregate, "Identifies the maximum selected value.",
TypeBool, TypeInt, TypeFloat, TypeDecimal, TypeString, TypeBytes,
TypeDate, TypeTimestamp, TypeTimestampTZ, TypeInterval),
"min": makeAggBuiltins(newMinAggregate, "Identifies the minimum selected value.",
TypeBool, TypeInt, TypeFloat, TypeDecimal, TypeString, TypeBytes,
TypeDate, TypeTimestamp, TypeTimestampTZ, TypeInterval),
"max": collectBuiltins(func(t Type) Builtin {
return makeAggBuiltin(t, t, newMaxAggregate,
"Identifies the maximum selected value.")
}, TypesAnyNonArray...),
"min": collectBuiltins(func(t Type) Builtin {
return makeAggBuiltin(t, t, newMinAggregate,
"Identifies the minimum selected value.")
}, TypesAnyNonArray...),

"sum_int": {
makeAggBuiltin(TypeInt, TypeInt, newSmallIntSumAggregate,
Expand Down Expand Up @@ -161,41 +172,28 @@ var Aggregates = map[string][]Builtin{
},
}

func makeAggBuiltin(in, ret Type, f func() AggregateFunc, info string) Builtin {
func makeAggBuiltin(in, ret Type, f func([]Type) AggregateFunc, info string) Builtin {
return makeAggBuiltinWithReturnType(in, fixedReturnType(ret), f, info)
}

func makeAggBuiltinWithReturnType(
in Type, retType returnTyper, f func([]Type) AggregateFunc, info string,
) Builtin {
return Builtin{
// See the comment about aggregate functions in the definitions
// of the Builtins array above.
impure: true,
class: AggregateClass,
Types: ArgTypes{{"arg", in}},
ReturnType: ret,
ReturnType: retType,
AggregateFunc: f,
WindowFunc: func() WindowFunc {
return newAggregateWindow(f())
WindowFunc: func(params []Type) WindowFunc {
return newAggregateWindow(f(params))
},
Info: info,
}
}

func makeAggBuiltins(f func() AggregateFunc, info string, types ...Type) []Builtin {
ret := make([]Builtin, len(types))
for i := range types {
ret[i] = makeAggBuiltin(types[i], types[i], f, info)
}
return ret
}

func countImpls() []Builtin {
types := []Type{TypeBool, TypeInt, TypeFloat, TypeDecimal, TypeString, TypeBytes,
TypeDate, TypeTimestamp, TypeTimestampTZ, TypeInterval, TypeTuple}
r := make([]Builtin, len(types))
for i := range types {
r[i] = makeAggBuiltin(types[i], TypeInt, newCountAggregate,
"Calculates the number of selected elements.")
}
return r
}

var _ AggregateFunc = &arrayAggregate{}
var _ AggregateFunc = &avgAggregate{}
var _ AggregateFunc = &countAggregate{}
Expand Down Expand Up @@ -248,12 +246,10 @@ type arrayAggregate struct {
arr *DArray
}

func newIntArrayAggregate() AggregateFunc {
return &arrayAggregate{arr: NewDArray(TypeInt)}
}

func newStringArrayAggregate() AggregateFunc {
return &arrayAggregate{arr: NewDArray(TypeString)}
func newArrayAggregate(params []Type) AggregateFunc {
return &arrayAggregate{
arr: NewDArray(params[0]),
}
}

// Add accumulates the passed datum into the array.
Expand All @@ -276,14 +272,14 @@ type avgAggregate struct {
count int
}

func newIntAvgAggregate() AggregateFunc {
return &avgAggregate{agg: newIntSumAggregate()}
func newIntAvgAggregate(params []Type) AggregateFunc {
return &avgAggregate{agg: newIntSumAggregate(params)}
}
func newFloatAvgAggregate() AggregateFunc {
return &avgAggregate{agg: newFloatSumAggregate()}
func newFloatAvgAggregate(params []Type) AggregateFunc {
return &avgAggregate{agg: newFloatSumAggregate(params)}
}
func newDecimalAvgAggregate() AggregateFunc {
return &avgAggregate{agg: newDecimalSumAggregate()}
func newDecimalAvgAggregate(params []Type) AggregateFunc {
return &avgAggregate{agg: newDecimalSumAggregate(params)}
}

// Add accumulates the passed datum into the average.
Expand Down Expand Up @@ -319,10 +315,10 @@ type concatAggregate struct {
result bytes.Buffer
}

func newBytesConcatAggregate() AggregateFunc {
func newBytesConcatAggregate(_ []Type) AggregateFunc {
return &concatAggregate{forBytes: true}
}
func newStringConcatAggregate() AggregateFunc {
func newStringConcatAggregate(_ []Type) AggregateFunc {
return &concatAggregate{}
}

Expand Down Expand Up @@ -357,7 +353,7 @@ type boolAndAggregate struct {
result bool
}

func newBoolAndAggregate() AggregateFunc {
func newBoolAndAggregate(_ []Type) AggregateFunc {
return &boolAndAggregate{}
}

Expand All @@ -384,7 +380,7 @@ type boolOrAggregate struct {
result bool
}

func newBoolOrAggregate() AggregateFunc {
func newBoolOrAggregate(_ []Type) AggregateFunc {
return &boolOrAggregate{}
}

Expand All @@ -407,7 +403,7 @@ type countAggregate struct {
count int
}

func newCountAggregate() AggregateFunc {
func newCountAggregate(_ []Type) AggregateFunc {
return &countAggregate{}
}

Expand All @@ -428,7 +424,7 @@ type MaxAggregate struct {
max Datum
}

func newMaxAggregate() AggregateFunc {
func newMaxAggregate(_ []Type) AggregateFunc {
return &MaxAggregate{}
}

Expand Down Expand Up @@ -460,7 +456,7 @@ type MinAggregate struct {
min Datum
}

func newMinAggregate() AggregateFunc {
func newMinAggregate(_ []Type) AggregateFunc {
return &MinAggregate{}
}

Expand Down Expand Up @@ -492,7 +488,7 @@ type smallIntSumAggregate struct {
seenNonNull bool
}

func newSmallIntSumAggregate() AggregateFunc {
func newSmallIntSumAggregate(_ []Type) AggregateFunc {
return &smallIntSumAggregate{}
}

Expand Down Expand Up @@ -525,7 +521,7 @@ type intSumAggregate struct {
seenNonNull bool
}

func newIntSumAggregate() AggregateFunc {
func newIntSumAggregate(_ []Type) AggregateFunc {
return &intSumAggregate{}
}

Expand Down Expand Up @@ -579,7 +575,7 @@ type decimalSumAggregate struct {
sawNonNull bool
}

func newDecimalSumAggregate() AggregateFunc {
func newDecimalSumAggregate(_ []Type) AggregateFunc {
return &decimalSumAggregate{}
}

Expand Down Expand Up @@ -608,7 +604,7 @@ type floatSumAggregate struct {
sawNonNull bool
}

func newFloatSumAggregate() AggregateFunc {
func newFloatSumAggregate(_ []Type) AggregateFunc {
return &floatSumAggregate{}
}

Expand All @@ -635,7 +631,7 @@ type intervalSumAggregate struct {
sawNonNull bool
}

func newIntervalSumAggregate() AggregateFunc {
func newIntervalSumAggregate(_ []Type) AggregateFunc {
return &intervalSumAggregate{}
}

Expand Down Expand Up @@ -663,7 +659,7 @@ type intVarianceAggregate struct {
tmpDec DDecimal
}

func newIntVarianceAggregate() AggregateFunc {
func newIntVarianceAggregate(_ []Type) AggregateFunc {
return &intVarianceAggregate{}
}

Expand All @@ -686,7 +682,7 @@ type floatVarianceAggregate struct {
sqrDiff float64
}

func newFloatVarianceAggregate() AggregateFunc {
func newFloatVarianceAggregate(_ []Type) AggregateFunc {
return &floatVarianceAggregate{}
}

Expand Down Expand Up @@ -723,7 +719,7 @@ type decimalVarianceAggregate struct {
tmp inf.Dec
}

func newDecimalVarianceAggregate() AggregateFunc {
func newDecimalVarianceAggregate(_ []Type) AggregateFunc {
return &decimalVarianceAggregate{}
}

Expand Down Expand Up @@ -764,14 +760,14 @@ type stdDevAggregate struct {
agg AggregateFunc
}

func newIntStdDevAggregate() AggregateFunc {
return &stdDevAggregate{agg: newIntVarianceAggregate()}
func newIntStdDevAggregate(params []Type) AggregateFunc {
return &stdDevAggregate{agg: newIntVarianceAggregate(params)}
}
func newFloatStdDevAggregate() AggregateFunc {
return &stdDevAggregate{agg: newFloatVarianceAggregate()}
func newFloatStdDevAggregate(params []Type) AggregateFunc {
return &stdDevAggregate{agg: newFloatVarianceAggregate(params)}
}
func newDecimalStdDevAggregate() AggregateFunc {
return &stdDevAggregate{agg: newDecimalVarianceAggregate()}
func newDecimalStdDevAggregate(params []Type) AggregateFunc {
return &stdDevAggregate{agg: newDecimalVarianceAggregate(params)}
}

// Add implements the AggregateFunc interface.
Expand Down
9 changes: 5 additions & 4 deletions pkg/sql/parser/aggregate_builtins_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (
// printing all values to strings once the accumulation has finished. If the string
// slices are not equal, it means that the result Datums were modified during later
// accumulation, which violates the "deep copy of any internal state" condition.
func testAggregateResultDeepCopy(t *testing.T, aggFunc func() AggregateFunc, vals []Datum) {
aggImpl := aggFunc()
func testAggregateResultDeepCopy(t *testing.T, aggFunc func([]Type) AggregateFunc, vals []Datum) {
aggImpl := aggFunc([]Type{vals[0].ResolvedType()})
runningDatums := make([]Datum, len(vals))
runningStrings := make([]string, len(vals))
for i := range vals {
Expand Down Expand Up @@ -223,10 +223,11 @@ func makeIntervalTestDatum(count int) []Datum {
return vals
}

func runBenchmarkAggregate(b *testing.B, aggFunc func() AggregateFunc, vals []Datum) {
func runBenchmarkAggregate(b *testing.B, aggFunc func([]Type) AggregateFunc, vals []Datum) {
params := []Type{vals[0].ResolvedType()}
b.ResetTimer()
for i := 0; i < b.N; i++ {
aggImpl := aggFunc()
aggImpl := aggFunc(params)
for i := range vals {
aggImpl.Add(vals[i])
}
Expand Down
Loading