From 5f84335fbb1e1467eed07563be00e9338a06ff03 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 22 Aug 2022 11:08:42 -0400 Subject: [PATCH] ARROW-17475: [Go] Function interface and Registry impl (#13924) Authored-by: Matt Topol Signed-off-by: Matt Topol --- dev/release/rat_exclude_files.txt | 1 + go/arrow/compute/doc.go | 29 ++++ go/arrow/compute/expression.go | 4 + go/arrow/compute/funckind_string.go | 27 ++++ go/arrow/compute/functions.go | 63 +++++++++ go/arrow/compute/functions_test.go | 46 +++++++ go/arrow/compute/registry.go | 201 ++++++++++++++++++++++++++++ go/arrow/compute/registry_test.go | 176 ++++++++++++++++++++++++ 8 files changed, 547 insertions(+) create mode 100644 go/arrow/compute/doc.go create mode 100644 go/arrow/compute/funckind_string.go create mode 100644 go/arrow/compute/functions.go create mode 100644 go/arrow/compute/functions_test.go create mode 100644 go/arrow/compute/registry.go create mode 100644 go/arrow/compute/registry_test.go diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 5aaf64039add2..c9d0309425c4e 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -143,6 +143,7 @@ go/arrow/cdata/test/go.sum go/arrow/unionmode_string.go go/arrow/compute/go.sum go/arrow/compute/datumkind_string.go +go/arrow/compute/funckind_string.go go/*.tmpldata go/*.s go/parquet/internal/gen-go/parquet/GoUnusedProtection__.go diff --git a/go/arrow/compute/doc.go b/go/arrow/compute/doc.go new file mode 100644 index 0000000000000..ee19cd4f965f8 --- /dev/null +++ b/go/arrow/compute/doc.go @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package compute is a native-go implementation of an Acero-like +// arrow compute engine. +// +// While consumers of Arrow that are able to use CGO could utilize the +// C Data API (using the cdata package) and could link against the +// acero library directly, there are consumers who cannot use CGO. This +// is an attempt to provide for those users, and in general create a +// native-go arrow compute engine. +// +// Everything in this package should be considered Experimental for now. +package compute + +//go:generate stringer -type=FuncKind -linecomment diff --git a/go/arrow/compute/expression.go b/go/arrow/compute/expression.go index 2dd4ab626afdc..b42d3bc335fb2 100644 --- a/go/arrow/compute/expression.go +++ b/go/arrow/compute/expression.go @@ -450,6 +450,10 @@ type FunctionOptionsEqual interface { Equals(FunctionOptions) bool } +type FunctionOptionsCloneable interface { + Clone() FunctionOptions +} + type MakeStructOptions struct { FieldNames []string `compute:"field_names"` FieldNullability []bool `compute:"field_nullability"` diff --git a/go/arrow/compute/funckind_string.go b/go/arrow/compute/funckind_string.go new file mode 100644 index 0000000000000..97d3eaa031387 --- /dev/null +++ b/go/arrow/compute/funckind_string.go @@ -0,0 +1,27 @@ +// Code generated by "stringer -type=FuncKind -linecomment"; DO NOT EDIT. + +package compute + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[FuncScalar-0] + _ = x[FuncVector-1] + _ = x[FuncScalarAgg-2] + _ = x[FuncHashAgg-3] + _ = x[FuncMeta-4] +} + +const _FuncKind_name = "ScalarVectorScalarAggregateHashAggregateMeta" + +var _FuncKind_index = [...]uint8{0, 6, 12, 27, 40, 44} + +func (i FuncKind) String() string { + if i < 0 || i >= FuncKind(len(_FuncKind_index)-1) { + return "FuncKind(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _FuncKind_name[_FuncKind_index[i]:_FuncKind_index[i+1]] +} diff --git a/go/arrow/compute/functions.go b/go/arrow/compute/functions.go new file mode 100644 index 0000000000000..faa7981ca04e4 --- /dev/null +++ b/go/arrow/compute/functions.go @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute + +import ( + "context" +) + +type Function interface { + Name() string + Kind() FuncKind + Arity() Arity + Doc() FunctionDoc + NumKernels() int + Execute(context.Context, FunctionOptions, ...Datum) (Datum, error) + DefaultOptions() FunctionOptions + Validate() error +} + +type Arity struct { + NArgs int + IsVarArgs bool +} + +func Nullary() Arity { return Arity{0, false} } +func Unary() Arity { return Arity{1, false} } +func Binary() Arity { return Arity{2, false} } +func Ternary() Arity { return Arity{3, false} } +func VarArgs(minArgs int) Arity { return Arity{minArgs, true} } + +type FunctionDoc struct { + Summary string + Description string + ArgNames []string + OptionsClass string + OptionsRequired bool +} + +var EmptyFuncDoc FunctionDoc + +type FuncKind int8 + +const ( + FuncScalar FuncKind = iota // Scalar + FuncVector // Vector + FuncScalarAgg // ScalarAggregate + FuncHashAgg // HashAggregate + FuncMeta // Meta +) diff --git a/go/arrow/compute/functions_test.go b/go/arrow/compute/functions_test.go new file mode 100644 index 0000000000000..78dbd8be5e4f1 --- /dev/null +++ b/go/arrow/compute/functions_test.go @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute_test + +import ( + "testing" + + "github.com/apache/arrow/go/v10/arrow/compute" + "github.com/stretchr/testify/assert" +) + +func TestArityBasics(t *testing.T) { + nullary := compute.Nullary() + assert.Equal(t, 0, nullary.NArgs) + assert.False(t, nullary.IsVarArgs) + + unary := compute.Unary() + assert.Equal(t, 1, unary.NArgs) + assert.False(t, unary.IsVarArgs) + + binary := compute.Binary() + assert.Equal(t, 2, binary.NArgs) + assert.False(t, binary.IsVarArgs) + + ternary := compute.Ternary() + assert.Equal(t, 3, ternary.NArgs) + assert.False(t, ternary.IsVarArgs) + + varargs := compute.VarArgs(2) + assert.Equal(t, 2, varargs.NArgs) + assert.True(t, varargs.IsVarArgs) +} diff --git a/go/arrow/compute/registry.go b/go/arrow/compute/registry.go new file mode 100644 index 0000000000000..b749cd9d0e6f3 --- /dev/null +++ b/go/arrow/compute/registry.go @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute + +import ( + "sync" + + "github.com/apache/arrow/go/v10/arrow/internal/debug" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" +) + +type FunctionRegistry interface { + CanAddFunction(fn Function, allowOverwrite bool) bool + AddFunction(fn Function, allowOverwrite bool) bool + CanAddAlias(target, source string) bool + AddAlias(target, source string) bool + GetFunction(name string) (Function, bool) + GetFunctionNames() []string + NumFunctions() int + + canAddFuncName(string, bool) bool +} + +var ( + registry FunctionRegistry + once sync.Once +) + +func GetFunctionRegistry() FunctionRegistry { + once.Do(func() { + registry = NewRegistry() + // initialize the others + }) + return registry +} + +func NewRegistry() FunctionRegistry { + return &funcRegistry{ + nameToFunction: make(map[string]Function)} +} + +func NewChildRegistry(parent FunctionRegistry) FunctionRegistry { + return &funcRegistry{ + parent: parent.(*funcRegistry), + nameToFunction: make(map[string]Function)} +} + +type funcRegistry struct { + parent *funcRegistry + + mx sync.RWMutex + nameToFunction map[string]Function +} + +func (reg *funcRegistry) getLocker(add bool) sync.Locker { + if add { + return ®.mx + } + return reg.mx.RLocker() +} + +func (reg *funcRegistry) CanAddFunction(fn Function, allowOverwrite bool) bool { + if reg.parent != nil && !reg.parent.CanAddFunction(fn, allowOverwrite) { + return false + } + + return reg.doAddFunction(fn, allowOverwrite, false) +} + +func (reg *funcRegistry) AddFunction(fn Function, allowOverwrite bool) bool { + if reg.parent != nil && !reg.parent.CanAddFunction(fn, allowOverwrite) { + return false + } + + return reg.doAddFunction(fn, allowOverwrite, true) +} + +func (reg *funcRegistry) CanAddAlias(target, source string) bool { + if reg.parent != nil && !reg.parent.canAddFuncName(target, false) { + return false + } + return reg.doAddAlias(target, source, false) +} + +func (reg *funcRegistry) AddAlias(target, source string) bool { + if reg.parent != nil && !reg.parent.canAddFuncName(target, false) { + return false + } + + return reg.doAddAlias(target, source, true) +} + +func (reg *funcRegistry) GetFunction(name string) (Function, bool) { + reg.mx.RLock() + defer reg.mx.RUnlock() + + if fn, ok := reg.nameToFunction[name]; ok { + return fn, ok + } + + if reg.parent != nil { + return reg.parent.GetFunction(name) + } + + return nil, false +} + +func (reg *funcRegistry) GetFunctionNames() (out []string) { + if reg.parent != nil { + out = reg.parent.GetFunctionNames() + } else { + out = make([]string, 0, len(reg.nameToFunction)) + } + reg.mx.RLock() + defer reg.mx.RUnlock() + + out = append(out, maps.Keys(reg.nameToFunction)...) + slices.Sort(out) + return +} + +func (reg *funcRegistry) NumFunctions() (n int) { + if reg.parent != nil { + n = reg.parent.NumFunctions() + } + reg.mx.RLock() + defer reg.mx.RUnlock() + return n + len(reg.nameToFunction) +} + +func (reg *funcRegistry) canAddFuncName(name string, allowOverwrite bool) bool { + if reg.parent != nil { + reg.parent.mx.RLock() + defer reg.parent.mx.RUnlock() + + if !reg.parent.canAddFuncName(name, allowOverwrite) { + return false + } + } + if !allowOverwrite { + _, ok := reg.nameToFunction[name] + return !ok + } + return true +} + +func (reg *funcRegistry) doAddFunction(fn Function, allowOverwrite bool, add bool) bool { + debug.Assert(fn.Validate() == nil, "invalid function") + + lk := reg.getLocker(add) + lk.Lock() + defer lk.Unlock() + + name := fn.Name() + if !reg.canAddFuncName(name, allowOverwrite) { + return false + } + + if add { + reg.nameToFunction[name] = fn + } + return true +} + +func (reg *funcRegistry) doAddAlias(target, source string, add bool) bool { + // source name must exist in the registry or the parent + // check outside the mutex, in case GetFunction has a mutex + // acquisition + fn, ok := reg.GetFunction(source) + if !ok { + return false + } + + lk := reg.getLocker(add) + lk.Lock() + defer lk.Unlock() + + if !reg.canAddFuncName(target, false) { + return false + } + + if add { + reg.nameToFunction[target] = fn + } + return true +} diff --git a/go/arrow/compute/registry_test.go b/go/arrow/compute/registry_test.go new file mode 100644 index 0000000000000..01f9b07949426 --- /dev/null +++ b/go/arrow/compute/registry_test.go @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute_test + +import ( + "context" + "errors" + "testing" + + "github.com/apache/arrow/go/v10/arrow/compute" + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" +) + +var registry compute.FunctionRegistry + +func init() { + // make tests fail if there's a problem initializing the global + // function registry + registry = compute.GetFunctionRegistry() +} + +type mockFn struct { + name string +} + +func (m *mockFn) Name() string { return m.name } +func (*mockFn) Kind() compute.FuncKind { return compute.FuncScalar } +func (*mockFn) Arity() compute.Arity { return compute.Unary() } +func (*mockFn) Doc() compute.FunctionDoc { return compute.EmptyFuncDoc } +func (*mockFn) NumKernels() int { return 0 } +func (*mockFn) Execute(context.Context, compute.FunctionOptions, ...compute.Datum) (compute.Datum, error) { + return nil, errors.New("not implemented") +} +func (*mockFn) DefaultOptions() compute.FunctionOptions { return nil } +func (*mockFn) Validate() error { return nil } + +func TestRegistryBasics(t *testing.T) { + tests := []struct { + name string + factory func() compute.FunctionRegistry + nfuncs int + expectedNames []string + }{ + {"default", compute.NewRegistry, 0, []string{}}, + {"nested", func() compute.FunctionRegistry { + return compute.NewChildRegistry(registry) + }, registry.NumFunctions(), registry.GetFunctionNames()}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + registry := tt.factory() + assert.Equal(t, tt.nfuncs, registry.NumFunctions()) + + fn := &mockFn{name: "f1"} + assert.True(t, registry.AddFunction(fn, false)) + assert.Equal(t, tt.nfuncs+1, registry.NumFunctions()) + + f1, ok := registry.GetFunction("f1") + assert.True(t, ok) + assert.Same(t, fn, f1) + + // non-existent + _, ok = registry.GetFunction("f2") + assert.False(t, ok) + + // name collision + f2 := &mockFn{name: "f1"} + assert.False(t, registry.AddFunction(f2, false)) + + // allow overwriting + assert.True(t, registry.AddFunction(f2, true)) + f1, ok = registry.GetFunction("f1") + assert.True(t, ok) + assert.Same(t, f2, f1) + + expected := append(tt.expectedNames, "f1") + slices.Sort(expected) + assert.Equal(t, expected, registry.GetFunctionNames()) + + // aliases + assert.False(t, registry.AddAlias("f33", "f3")) // doesn't exist + assert.True(t, registry.AddAlias("f11", "f1")) + f1, ok = registry.GetFunction("f11") + assert.True(t, ok) + assert.Same(t, f2, f1) + }) + } +} + +func TestRegistry(t *testing.T) { + defaultRegistry := registry + t.Run("RegisterTempFunctions", func(t *testing.T) { + const rounds = 3 + for i := 0; i < rounds; i++ { + registry := compute.NewChildRegistry(registry) + for _, v := range []string{"f1", "f2"} { + fn := &mockFn{name: v} + assert.True(t, registry.CanAddFunction(fn, false)) + assert.True(t, registry.AddFunction(fn, false)) + assert.False(t, registry.CanAddFunction(fn, false)) + assert.False(t, registry.AddFunction(fn, false)) + assert.True(t, defaultRegistry.CanAddFunction(fn, false)) + } + } + }) + + t.Run("RegisterTempAliases", func(t *testing.T) { + funcNames := defaultRegistry.GetFunctionNames() + const rounds = 3 + for i := 0; i < rounds; i++ { + registry := compute.NewChildRegistry(registry) + for _, funcName := range funcNames { + alias := "alias_of_" + funcName + _, ok := registry.GetFunction(alias) + assert.False(t, ok) + assert.True(t, registry.CanAddAlias(alias, funcName)) + assert.True(t, registry.AddAlias(alias, funcName)) + _, ok = registry.GetFunction(alias) + assert.True(t, ok) + _, ok = defaultRegistry.GetFunction(funcName) + assert.True(t, ok) + _, ok = defaultRegistry.GetFunction(alias) + assert.False(t, ok) + } + } + }) +} + +func TestRegistryRegisterNestedFunction(t *testing.T) { + defaultRegistry := registry + func1 := &mockFn{name: "f1"} + func2 := &mockFn{name: "f2"} + + const rounds = 3 + for i := 0; i < rounds; i++ { + registry1 := compute.NewChildRegistry(defaultRegistry) + + assert.True(t, registry1.CanAddFunction(func1, false)) + assert.True(t, registry1.AddFunction(func1, false)) + for j := 0; j < rounds; j++ { + registry2 := compute.NewChildRegistry(registry1) + assert.False(t, registry2.CanAddFunction(func1, false)) + assert.False(t, registry2.AddFunction(func1, false)) + + assert.True(t, registry2.CanAddFunction(func2, false)) + assert.True(t, registry2.AddFunction(func2, false)) + assert.False(t, registry2.CanAddFunction(func2, false)) + assert.False(t, registry2.AddFunction(func2, false)) + assert.True(t, defaultRegistry.CanAddFunction(func2, false)) + + assert.False(t, registry2.CanAddAlias("f1", "f2")) + assert.False(t, registry2.AddAlias("f1", "f2")) + assert.False(t, registry2.AddAlias("f1", "f1")) + } + assert.False(t, registry1.CanAddFunction(func1, false)) + assert.False(t, registry1.AddFunction(func1, false)) + assert.True(t, registry1.CanAddAlias("f2", "f1")) + assert.True(t, defaultRegistry.CanAddFunction(func1, false)) + } +}