From 64549173a6e589c3153c7058be17447fcd7d4618 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 19 Aug 2022 11:18:58 -0400 Subject: [PATCH 1/3] ARROW-17475: [Go] Function interface and Registry impl --- dev/release/rat_exclude_files.txt | 1 + go/arrow/compute/doc.go | 29 +++ go/arrow/compute/funckind_string.go | 27 +++ go/arrow/compute/functions.go | 63 +++++++ go/arrow/compute/functions_test.go | 46 +++++ go/arrow/compute/go.mod | 3 +- go/arrow/compute/go.sum | 44 ++++- go/arrow/compute/registry.go | 266 ++++++++++++++++++++++++++++ go/arrow/compute/registry_test.go | 236 ++++++++++++++++++++++++ 9 files changed, 712 insertions(+), 3 deletions(-) create mode 100644 go/arrow/compute/doc.go create mode 100644 go/arrow/compute/funckind_string.go create mode 100644 go/arrow/compute/functions.go create mode 100644 go/arrow/compute/functions_test.go create mode 100644 go/arrow/compute/registry.go create mode 100644 go/arrow/compute/registry_test.go diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 316ff7b556402..5a0860aad323f 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -144,6 +144,7 @@ go/arrow/unionmode_string.go go/arrow/compute/go.sum go/arrow/compute/datumkind_string.go go/arrow/compute/valueshape_string.go +go/arrow/compute/funckind_string.go go/*.tmpldata go/*.s go/parquet/internal/gen-go/parquet/GoUnusedProtection__.go diff --git a/go/arrow/compute/doc.go b/go/arrow/compute/doc.go new file mode 100644 index 0000000000000..ee19cd4f965f8 --- /dev/null +++ b/go/arrow/compute/doc.go @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package compute is a native-go implementation of an Acero-like +// arrow compute engine. +// +// While consumers of Arrow that are able to use CGO could utilize the +// C Data API (using the cdata package) and could link against the +// acero library directly, there are consumers who cannot use CGO. This +// is an attempt to provide for those users, and in general create a +// native-go arrow compute engine. +// +// Everything in this package should be considered Experimental for now. +package compute + +//go:generate stringer -type=FuncKind -linecomment diff --git a/go/arrow/compute/funckind_string.go b/go/arrow/compute/funckind_string.go new file mode 100644 index 0000000000000..97d3eaa031387 --- /dev/null +++ b/go/arrow/compute/funckind_string.go @@ -0,0 +1,27 @@ +// Code generated by "stringer -type=FuncKind -linecomment"; DO NOT EDIT. + +package compute + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[FuncScalar-0] + _ = x[FuncVector-1] + _ = x[FuncScalarAgg-2] + _ = x[FuncHashAgg-3] + _ = x[FuncMeta-4] +} + +const _FuncKind_name = "ScalarVectorScalarAggregateHashAggregateMeta" + +var _FuncKind_index = [...]uint8{0, 6, 12, 27, 40, 44} + +func (i FuncKind) String() string { + if i < 0 || i >= FuncKind(len(_FuncKind_index)-1) { + return "FuncKind(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _FuncKind_name[_FuncKind_index[i]:_FuncKind_index[i+1]] +} diff --git a/go/arrow/compute/functions.go b/go/arrow/compute/functions.go new file mode 100644 index 0000000000000..faa7981ca04e4 --- /dev/null +++ b/go/arrow/compute/functions.go @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute + +import ( + "context" +) + +type Function interface { + Name() string + Kind() FuncKind + Arity() Arity + Doc() FunctionDoc + NumKernels() int + Execute(context.Context, FunctionOptions, ...Datum) (Datum, error) + DefaultOptions() FunctionOptions + Validate() error +} + +type Arity struct { + NArgs int + IsVarArgs bool +} + +func Nullary() Arity { return Arity{0, false} } +func Unary() Arity { return Arity{1, false} } +func Binary() Arity { return Arity{2, false} } +func Ternary() Arity { return Arity{3, false} } +func VarArgs(minArgs int) Arity { return Arity{minArgs, true} } + +type FunctionDoc struct { + Summary string + Description string + ArgNames []string + OptionsClass string + OptionsRequired bool +} + +var EmptyFuncDoc FunctionDoc + +type FuncKind int8 + +const ( + FuncScalar FuncKind = iota // Scalar + FuncVector // Vector + FuncScalarAgg // ScalarAggregate + FuncHashAgg // HashAggregate + FuncMeta // Meta +) diff --git a/go/arrow/compute/functions_test.go b/go/arrow/compute/functions_test.go new file mode 100644 index 0000000000000..78dbd8be5e4f1 --- /dev/null +++ b/go/arrow/compute/functions_test.go @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute_test + +import ( + "testing" + + "github.com/apache/arrow/go/v10/arrow/compute" + "github.com/stretchr/testify/assert" +) + +func TestArityBasics(t *testing.T) { + nullary := compute.Nullary() + assert.Equal(t, 0, nullary.NArgs) + assert.False(t, nullary.IsVarArgs) + + unary := compute.Unary() + assert.Equal(t, 1, unary.NArgs) + assert.False(t, unary.IsVarArgs) + + binary := compute.Binary() + assert.Equal(t, 2, binary.NArgs) + assert.False(t, binary.IsVarArgs) + + ternary := compute.Ternary() + assert.Equal(t, 3, ternary.NArgs) + assert.False(t, ternary.IsVarArgs) + + varargs := compute.VarArgs(2) + assert.Equal(t, 2, varargs.NArgs) + assert.True(t, varargs.IsVarArgs) +} diff --git a/go/arrow/compute/go.mod b/go/arrow/compute/go.mod index 58361f10c823d..447802c46b1d4 100644 --- a/go/arrow/compute/go.mod +++ b/go/arrow/compute/go.mod @@ -23,6 +23,7 @@ replace github.com/apache/arrow/go/v10 => ../../ require ( github.com/apache/arrow/go/v10 v10.0.0-00010101000000-000000000000 github.com/stretchr/testify v1.8.0 + golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f ) @@ -42,7 +43,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/sys v0.0.0-20220804214406-8e32c043e418 // indirect + golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 // indirect golang.org/x/tools v0.1.12 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go/arrow/compute/go.sum b/go/arrow/compute/go.sum index 173afed769b90..b05bdd419c7c4 100644 --- a/go/arrow/compute/go.sum +++ b/go/arrow/compute/go.sum @@ -31,6 +31,7 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -78,14 +79,17 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= @@ -98,6 +102,9 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= @@ -112,6 +119,7 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w= github.com/ruudk/golang-pdf417 v0.0.0-20201230142125-a7e3863a1245/go.mod h1:pQAZKsJ8yyVxGRWYNEm9oFB8ieLgKFnamEyDmSA0BRk= @@ -192,6 +200,7 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -199,11 +208,13 @@ golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210304124612-50617c2ba197/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220804214406-8e32c043e418 h1:9vYwv7OjYaky/tlAeD7C4oC9EsPTlaFl1H2jS++V+ME= -golang.org/x/sys v0.0.0-20220804214406-8e32c043e418/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= +golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -220,6 +231,7 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= @@ -277,4 +289,32 @@ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= +lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.36.0/go.mod h1:NFUHyPn4ekoC/JHeZFfZurN6ixxawE1BnVonP/oahEI= +modernc.org/cc/v3 v3.36.1/go.mod h1:NFUHyPn4ekoC/JHeZFfZurN6ixxawE1BnVonP/oahEI= +modernc.org/ccgo/v3 v3.0.0-20220428102840-41399a37e894/go.mod h1:eI31LL8EwEBKPpNpA4bU1/i+sKOwOrQy8D87zWUcRZc= +modernc.org/ccgo/v3 v3.0.0-20220430103911-bc99d88307be/go.mod h1:bwdAnOoaIt8Ax9YdWGjxWsdkPcZyRPHqrOvJxaKAKGw= +modernc.org/ccgo/v3 v3.16.4/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= +modernc.org/ccgo/v3 v3.16.6/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= +modernc.org/ccgo/v3 v3.16.8/go.mod h1:zNjwkizS+fIFDrDjIAgBSCLkWbJuHF+ar3QRn+Z9aws= +modernc.org/ccorpus v1.11.6/go.mod h1:2gEUTrWqdpH2pXsmTM1ZkjeSrUWDpjMu2T6m29L/ErQ= +modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM= +modernc.org/libc v0.0.0-20220428101251-2d5f3daf273b/go.mod h1:p7Mg4+koNjc8jkqwcoFBJx7tXkpj00G77X7A72jXPXA= +modernc.org/libc v1.16.0/go.mod h1:N4LD6DBE9cf+Dzf9buBlzVJndKr/iJHG97vGLHYnb5A= +modernc.org/libc v1.16.1/go.mod h1:JjJE0eu4yeK7tab2n4S1w8tlWd9MxXLRzheaRnAKymU= +modernc.org/libc v1.16.7/go.mod h1:hYIV5VZczAmGZAnG15Vdngn5HSF5cSkbvfz2B7GRuVU= +modernc.org/libc v1.16.17/go.mod h1:hYIV5VZczAmGZAnG15Vdngn5HSF5cSkbvfz2B7GRuVU= +modernc.org/libc v1.16.19/go.mod h1:p7Mg4+koNjc8jkqwcoFBJx7tXkpj00G77X7A72jXPXA= +modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/mathutil v1.4.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.1.1/go.mod h1:/0wo5ibyrQiaoUoH7f9D8dnglAmILJ5/cxZlRECf+Nw= +modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.18.0/go.mod h1:B9fRWZacNxJBHoCJZQr1R54zhVn3fjfl0aszflrTSxY= +modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw= +modernc.org/strutil v1.1.2/go.mod h1:OYajnUAcI/MX+XD/Wx7v1bbdvcQSvxgtb0gC+u3d3eg= +modernc.org/tcl v1.13.1/go.mod h1:XOLfOwzhkljL4itZkK6T72ckMgvj0BDsnKNdZVUOecw= +modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +modernc.org/z v1.5.1/go.mod h1:eWFB510QWW5Th9YGZT81s+LwvaAs3Q2yr4sP0rmLkv8= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/go/arrow/compute/registry.go b/go/arrow/compute/registry.go new file mode 100644 index 0000000000000..c5638d07e0619 --- /dev/null +++ b/go/arrow/compute/registry.go @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute + +import ( + "sync" + + "github.com/apache/arrow/go/v10/arrow/internal/debug" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" +) + +type FunctionOptionsType interface { + TypeName() string + Compare(lhs, rhs FunctionOptions) bool + Copy(FunctionOptions) FunctionOptions +} + +type FunctionRegistry interface { + CanAddFunction(fn Function, allowOverwrite bool) bool + AddFunction(fn Function, allowOverwrite bool) bool + CanAddAlias(target, source string) bool + AddAlias(target, source string) bool + CanAddFunctionOptionsType(opts FunctionOptionsType, allowOverwrite bool) bool + AddFunctionOptionsType(opts FunctionOptionsType, allowOverwrite bool) bool + GetFunction(name string) (Function, bool) + GetFunctionNames() []string + GetFunctionOptionsType(name string) (FunctionOptionsType, bool) + NumFunctions() int + + canAddFuncName(string, bool) bool + canAddOptionsTypeName(name string, allowOverwrite bool) bool +} + +var ( + registry FunctionRegistry + once sync.Once +) + +func GetFunctionRegistry() FunctionRegistry { + once.Do(func() { + registry = NewRegistry() + // initialize the others + }) + return registry +} + +func NewRegistry() FunctionRegistry { + return &funcRegistry{ + nameToFunction: make(map[string]Function), + nameToOptsType: make(map[string]FunctionOptionsType)} +} + +func NewChildRegistry(parent FunctionRegistry) FunctionRegistry { + return &funcRegistry{ + parent: parent.(*funcRegistry), + nameToFunction: make(map[string]Function), + nameToOptsType: make(map[string]FunctionOptionsType)} +} + +type funcRegistry struct { + parent *funcRegistry + + mx sync.RWMutex + nameToFunction map[string]Function + nameToOptsType map[string]FunctionOptionsType +} + +func (reg *funcRegistry) getLocker(add bool) sync.Locker { + if add { + return ®.mx + } + return reg.mx.RLocker() +} + +func (reg *funcRegistry) CanAddFunction(fn Function, allowOverwrite bool) bool { + if reg.parent != nil && !reg.parent.CanAddFunction(fn, allowOverwrite) { + return false + } + + return reg.doAddFunction(fn, allowOverwrite, false) +} + +func (reg *funcRegistry) AddFunction(fn Function, allowOverwrite bool) bool { + if reg.parent != nil && !reg.parent.CanAddFunction(fn, allowOverwrite) { + return false + } + + return reg.doAddFunction(fn, allowOverwrite, true) +} + +func (reg *funcRegistry) CanAddAlias(target, source string) bool { + if reg.parent != nil && !reg.parent.canAddFuncName(target, false) { + return false + } + return reg.doAddAlias(target, source, false) +} + +func (reg *funcRegistry) AddAlias(target, source string) bool { + if reg.parent != nil && !reg.parent.canAddFuncName(target, false) { + return false + } + + return reg.doAddAlias(target, source, true) +} + +func (reg *funcRegistry) CanAddFunctionOptionsType(opts FunctionOptionsType, allowOverwrite bool) bool { + if reg.parent != nil && !reg.parent.CanAddFunctionOptionsType(opts, allowOverwrite) { + return false + } + + return reg.doAddFuncOptionsType(opts, allowOverwrite, false) +} + +func (reg *funcRegistry) AddFunctionOptionsType(opts FunctionOptionsType, allowOverwrite bool) bool { + if reg.parent != nil && !reg.parent.CanAddFunctionOptionsType(opts, allowOverwrite) { + return false + } + + return reg.doAddFuncOptionsType(opts, allowOverwrite, true) +} + +func (reg *funcRegistry) GetFunction(name string) (Function, bool) { + reg.mx.RLock() + defer reg.mx.RUnlock() + + if fn, ok := reg.nameToFunction[name]; ok { + return fn, ok + } + + if reg.parent != nil { + return reg.parent.GetFunction(name) + } + + return nil, false +} + +func (reg *funcRegistry) GetFunctionNames() (out []string) { + if reg.parent != nil { + out = reg.parent.GetFunctionNames() + } else { + out = make([]string, 0, len(reg.nameToFunction)) + } + reg.mx.RLock() + defer reg.mx.RUnlock() + + out = append(out, maps.Keys(reg.nameToFunction)...) + slices.Sort(out) + return +} + +func (reg *funcRegistry) GetFunctionOptionsType(name string) (FunctionOptionsType, bool) { + reg.mx.RLock() + defer reg.mx.RUnlock() + + opts, ok := reg.nameToOptsType[name] + if ok { + return opts, true + } + + if reg.parent != nil { + return reg.parent.GetFunctionOptionsType(name) + } + return nil, false +} + +func (reg *funcRegistry) NumFunctions() (n int) { + if reg.parent != nil { + n = reg.parent.NumFunctions() + } + return n + len(reg.nameToFunction) +} + +func (reg *funcRegistry) canAddFuncName(name string, allowOverwrite bool) bool { + if reg.parent != nil && !reg.parent.canAddFuncName(name, allowOverwrite) { + return false + } + if !allowOverwrite { + _, ok := reg.nameToFunction[name] + return !ok + } + return true +} + +func (reg *funcRegistry) doAddFunction(fn Function, allowOverwrite bool, add bool) bool { + debug.Assert(fn.Validate() == nil, "invalid function") + + lk := reg.getLocker(add) + lk.Lock() + defer lk.Unlock() + + name := fn.Name() + if !reg.canAddFuncName(name, allowOverwrite) { + return false + } + + if add { + reg.nameToFunction[name] = fn + } + return true +} + +func (reg *funcRegistry) doAddAlias(target, source string, add bool) bool { + // source name must exist in the registry or the parent + // check outside the mutex, in case GetFunction has a mutex + // acquisition + fn, ok := reg.GetFunction(source) + if !ok { + return false + } + + lk := reg.getLocker(add) + lk.Lock() + defer lk.Unlock() + + if !reg.canAddFuncName(target, false) { + return false + } + + if add { + reg.nameToFunction[target] = fn + } + return true +} + +func (reg *funcRegistry) canAddOptionsTypeName(name string, allowOverwrite bool) bool { + if reg.parent != nil && !reg.parent.canAddOptionsTypeName(name, allowOverwrite) { + return false + } + + if !allowOverwrite { + _, ok := reg.nameToOptsType[name] + return !ok + } + return true +} + +func (reg *funcRegistry) doAddFuncOptionsType(opts FunctionOptionsType, allowOverwrite, add bool) bool { + lk := reg.getLocker(add) + lk.Lock() + defer lk.Unlock() + + name := opts.TypeName() + if !reg.canAddOptionsTypeName(name, false) { + return false + } + + if add { + reg.nameToOptsType[name] = opts + } + return true +} diff --git a/go/arrow/compute/registry_test.go b/go/arrow/compute/registry_test.go new file mode 100644 index 0000000000000..5120d905d163a --- /dev/null +++ b/go/arrow/compute/registry_test.go @@ -0,0 +1,236 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute_test + +import ( + "context" + "errors" + "strconv" + "testing" + "unsafe" + + "github.com/apache/arrow/go/v10/arrow/compute" + "github.com/apache/arrow/go/v10/arrow/scalar" + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" +) + +var registry compute.FunctionRegistry + +func init() { + // make tests fail if there's a problem initializing the global + // function registry + registry = compute.GetFunctionRegistry() +} + +type mockFn struct { + name string +} + +func (m *mockFn) Name() string { return m.name } +func (*mockFn) Kind() compute.FuncKind { return compute.FuncScalar } +func (*mockFn) Arity() compute.Arity { return compute.Unary() } +func (*mockFn) Doc() compute.FunctionDoc { return compute.EmptyFuncDoc } +func (*mockFn) NumKernels() int { return 0 } +func (*mockFn) Execute(context.Context, compute.FunctionOptions, ...compute.Datum) (compute.Datum, error) { + return nil, errors.New("not implemented") +} +func (*mockFn) DefaultOptions() compute.FunctionOptions { return nil } +func (*mockFn) Validate() error { return nil } + +func TestRegistryBasics(t *testing.T) { + tests := []struct { + name string + factory func() compute.FunctionRegistry + nfuncs int + expectedNames []string + }{ + {"default", compute.NewRegistry, 0, []string{}}, + {"nested", func() compute.FunctionRegistry { + return compute.NewChildRegistry(registry) + }, registry.NumFunctions(), registry.GetFunctionNames()}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + registry := tt.factory() + assert.Equal(t, tt.nfuncs, registry.NumFunctions()) + + fn := &mockFn{name: "f1"} + assert.True(t, registry.AddFunction(fn, false)) + assert.Equal(t, tt.nfuncs+1, registry.NumFunctions()) + + f1, ok := registry.GetFunction("f1") + assert.True(t, ok) + assert.Same(t, fn, f1) + + // non-existent + _, ok = registry.GetFunction("f2") + assert.False(t, ok) + + // name collision + f2 := &mockFn{name: "f1"} + assert.False(t, registry.AddFunction(f2, false)) + + // allow overwriting + assert.True(t, registry.AddFunction(f2, true)) + f1, ok = registry.GetFunction("f1") + assert.True(t, ok) + assert.Same(t, f2, f1) + + expected := append(tt.expectedNames, "f1") + slices.Sort(expected) + assert.Equal(t, expected, registry.GetFunctionNames()) + + // aliases + assert.False(t, registry.AddAlias("f33", "f3")) // doesn't exist + assert.True(t, registry.AddAlias("f11", "f1")) + f1, ok = registry.GetFunction("f11") + assert.True(t, ok) + assert.Same(t, f2, f1) + }) + } +} + +func TestRegistry(t *testing.T) { + defaultRegistry := registry + t.Run("RegisterTempFunctions", func(t *testing.T) { + const rounds = 3 + for i := 0; i < rounds; i++ { + registry := compute.NewChildRegistry(registry) + for _, v := range []string{"f1", "f2"} { + fn := &mockFn{name: v} + assert.True(t, registry.CanAddFunction(fn, false)) + assert.True(t, registry.AddFunction(fn, false)) + assert.False(t, registry.CanAddFunction(fn, false)) + assert.False(t, registry.AddFunction(fn, false)) + assert.True(t, defaultRegistry.CanAddFunction(fn, false)) + } + } + }) + + t.Run("RegisterTempAliases", func(t *testing.T) { + funcNames := defaultRegistry.GetFunctionNames() + const rounds = 3 + for i := 0; i < rounds; i++ { + registry := compute.NewChildRegistry(registry) + for _, funcName := range funcNames { + alias := "alias_of_" + funcName + _, ok := registry.GetFunction(alias) + assert.False(t, ok) + assert.True(t, registry.CanAddAlias(alias, funcName)) + assert.True(t, registry.AddAlias(alias, funcName)) + _, ok = registry.GetFunction(alias) + assert.True(t, ok) + _, ok = defaultRegistry.GetFunction(funcName) + assert.True(t, ok) + _, ok = defaultRegistry.GetFunction(alias) + assert.False(t, ok) + } + } + }) +} + +type ExampleOptions[T int32 | uint64] struct { + Value scalar.Scalar +} + +func (ExampleOptions[T]) TypeName() string { return "example" } + +type ExampleOptionsType[T int32 | uint64] struct{} + +func (*ExampleOptionsType[T]) TypeName() string { + return "example" + strconv.Itoa(int(unsafe.Sizeof(T(0)))) +} + +func (*ExampleOptionsType[T]) Compare(lhs, rhs compute.FunctionOptions) bool { + return true +} + +func (*ExampleOptionsType[T]) Copy(opts compute.FunctionOptions) compute.FunctionOptions { + o := opts.(ExampleOptions[T]) + return ExampleOptions[T]{Value: o.Value} +} + +func TestRegistryTempFunctionOptionsType(t *testing.T) { + defaultRegistry := registry + optsTypes := []compute.FunctionOptionsType{ + &ExampleOptionsType[int32]{}, + &ExampleOptionsType[uint64]{}, + } + const rounds = 3 + for i := 0; i < rounds; i++ { + registry := compute.NewChildRegistry(defaultRegistry) + for _, opttype := range optsTypes { + assert.True(t, registry.CanAddFunctionOptionsType(opttype, false)) + assert.True(t, registry.AddFunctionOptionsType(opttype, false)) + assert.False(t, registry.CanAddFunctionOptionsType(opttype, false)) + assert.False(t, registry.AddFunctionOptionsType(opttype, false)) + assert.True(t, defaultRegistry.CanAddFunctionOptionsType(opttype, false)) + + opt, ok := registry.GetFunctionOptionsType(opttype.TypeName()) + assert.True(t, ok) + assert.Same(t, opttype, opt) + + _, ok = registry.GetFunctionOptionsType("foobar") + assert.False(t, ok) + } + } +} + +func TestRegistryRegisterNestedFunction(t *testing.T) { + defaultRegistry := registry + func1 := &mockFn{name: "f1"} + func2 := &mockFn{name: "f2"} + + optType1 := &ExampleOptionsType[int32]{} + optType2 := &ExampleOptionsType[uint64]{} + + const rounds = 3 + for i := 0; i < rounds; i++ { + registry1 := compute.NewChildRegistry(defaultRegistry) + + assert.True(t, registry1.CanAddFunction(func1, false)) + assert.True(t, registry1.AddFunction(func1, false)) + assert.True(t, registry1.CanAddFunctionOptionsType(optType1, false)) + assert.True(t, registry1.AddFunctionOptionsType(optType1, false)) + for j := 0; j < rounds; j++ { + registry2 := compute.NewChildRegistry(registry1) + assert.False(t, registry2.CanAddFunction(func1, false)) + assert.False(t, registry2.AddFunction(func1, false)) + assert.False(t, registry2.CanAddFunctionOptionsType(optType1, false)) + assert.False(t, registry2.AddFunctionOptionsType(optType1, false)) + + assert.True(t, registry2.CanAddFunction(func2, false)) + assert.True(t, registry2.AddFunction(func2, false)) + assert.False(t, registry2.CanAddFunction(func2, false)) + assert.False(t, registry2.AddFunction(func2, false)) + assert.True(t, registry2.CanAddFunctionOptionsType(optType2, false)) + assert.True(t, registry2.AddFunctionOptionsType(optType2, false)) + assert.True(t, defaultRegistry.CanAddFunction(func2, false)) + assert.True(t, defaultRegistry.CanAddFunctionOptionsType(optType2, false)) + + assert.False(t, registry2.CanAddAlias("f1", "f2")) + assert.False(t, registry2.AddAlias("f1", "f2")) + assert.False(t, registry2.AddAlias("f1", "f1")) + } + assert.False(t, registry1.CanAddFunction(func1, false)) + assert.False(t, registry1.AddFunction(func1, false)) + assert.True(t, registry1.CanAddAlias("f2", "f1")) + assert.True(t, defaultRegistry.CanAddFunction(func1, false)) + } +} From ad586388f789bf3ce8d356255ed41fecc6465de9 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 19 Aug 2022 15:20:35 -0400 Subject: [PATCH 2/3] update from feedback, remove FunctionOptionsType --- go/arrow/compute/expression.go | 4 ++ go/arrow/compute/registry.go | 76 +------------------------------ go/arrow/compute/registry_test.go | 60 ------------------------ 3 files changed, 6 insertions(+), 134 deletions(-) diff --git a/go/arrow/compute/expression.go b/go/arrow/compute/expression.go index 8e895fc0c21f6..37f643d2bd60d 100644 --- a/go/arrow/compute/expression.go +++ b/go/arrow/compute/expression.go @@ -463,6 +463,10 @@ type FunctionOptionsEqual interface { Equals(FunctionOptions) bool } +type FunctionOptionsCloneable interface { + Clone() FunctionOptions +} + type MakeStructOptions struct { FieldNames []string `compute:"field_names"` FieldNullability []bool `compute:"field_nullability"` diff --git a/go/arrow/compute/registry.go b/go/arrow/compute/registry.go index c5638d07e0619..f33f6991e39e6 100644 --- a/go/arrow/compute/registry.go +++ b/go/arrow/compute/registry.go @@ -24,26 +24,16 @@ import ( "golang.org/x/exp/slices" ) -type FunctionOptionsType interface { - TypeName() string - Compare(lhs, rhs FunctionOptions) bool - Copy(FunctionOptions) FunctionOptions -} - type FunctionRegistry interface { CanAddFunction(fn Function, allowOverwrite bool) bool AddFunction(fn Function, allowOverwrite bool) bool CanAddAlias(target, source string) bool AddAlias(target, source string) bool - CanAddFunctionOptionsType(opts FunctionOptionsType, allowOverwrite bool) bool - AddFunctionOptionsType(opts FunctionOptionsType, allowOverwrite bool) bool GetFunction(name string) (Function, bool) GetFunctionNames() []string - GetFunctionOptionsType(name string) (FunctionOptionsType, bool) NumFunctions() int canAddFuncName(string, bool) bool - canAddOptionsTypeName(name string, allowOverwrite bool) bool } var ( @@ -61,15 +51,13 @@ func GetFunctionRegistry() FunctionRegistry { func NewRegistry() FunctionRegistry { return &funcRegistry{ - nameToFunction: make(map[string]Function), - nameToOptsType: make(map[string]FunctionOptionsType)} + nameToFunction: make(map[string]Function)} } func NewChildRegistry(parent FunctionRegistry) FunctionRegistry { return &funcRegistry{ parent: parent.(*funcRegistry), - nameToFunction: make(map[string]Function), - nameToOptsType: make(map[string]FunctionOptionsType)} + nameToFunction: make(map[string]Function)} } type funcRegistry struct { @@ -77,7 +65,6 @@ type funcRegistry struct { mx sync.RWMutex nameToFunction map[string]Function - nameToOptsType map[string]FunctionOptionsType } func (reg *funcRegistry) getLocker(add bool) sync.Locker { @@ -118,22 +105,6 @@ func (reg *funcRegistry) AddAlias(target, source string) bool { return reg.doAddAlias(target, source, true) } -func (reg *funcRegistry) CanAddFunctionOptionsType(opts FunctionOptionsType, allowOverwrite bool) bool { - if reg.parent != nil && !reg.parent.CanAddFunctionOptionsType(opts, allowOverwrite) { - return false - } - - return reg.doAddFuncOptionsType(opts, allowOverwrite, false) -} - -func (reg *funcRegistry) AddFunctionOptionsType(opts FunctionOptionsType, allowOverwrite bool) bool { - if reg.parent != nil && !reg.parent.CanAddFunctionOptionsType(opts, allowOverwrite) { - return false - } - - return reg.doAddFuncOptionsType(opts, allowOverwrite, true) -} - func (reg *funcRegistry) GetFunction(name string) (Function, bool) { reg.mx.RLock() defer reg.mx.RUnlock() @@ -163,21 +134,6 @@ func (reg *funcRegistry) GetFunctionNames() (out []string) { return } -func (reg *funcRegistry) GetFunctionOptionsType(name string) (FunctionOptionsType, bool) { - reg.mx.RLock() - defer reg.mx.RUnlock() - - opts, ok := reg.nameToOptsType[name] - if ok { - return opts, true - } - - if reg.parent != nil { - return reg.parent.GetFunctionOptionsType(name) - } - return nil, false -} - func (reg *funcRegistry) NumFunctions() (n int) { if reg.parent != nil { n = reg.parent.NumFunctions() @@ -236,31 +192,3 @@ func (reg *funcRegistry) doAddAlias(target, source string, add bool) bool { } return true } - -func (reg *funcRegistry) canAddOptionsTypeName(name string, allowOverwrite bool) bool { - if reg.parent != nil && !reg.parent.canAddOptionsTypeName(name, allowOverwrite) { - return false - } - - if !allowOverwrite { - _, ok := reg.nameToOptsType[name] - return !ok - } - return true -} - -func (reg *funcRegistry) doAddFuncOptionsType(opts FunctionOptionsType, allowOverwrite, add bool) bool { - lk := reg.getLocker(add) - lk.Lock() - defer lk.Unlock() - - name := opts.TypeName() - if !reg.canAddOptionsTypeName(name, false) { - return false - } - - if add { - reg.nameToOptsType[name] = opts - } - return true -} diff --git a/go/arrow/compute/registry_test.go b/go/arrow/compute/registry_test.go index 5120d905d163a..01f9b07949426 100644 --- a/go/arrow/compute/registry_test.go +++ b/go/arrow/compute/registry_test.go @@ -19,12 +19,9 @@ package compute_test import ( "context" "errors" - "strconv" "testing" - "unsafe" "github.com/apache/arrow/go/v10/arrow/compute" - "github.com/apache/arrow/go/v10/arrow/scalar" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" ) @@ -145,84 +142,27 @@ func TestRegistry(t *testing.T) { }) } -type ExampleOptions[T int32 | uint64] struct { - Value scalar.Scalar -} - -func (ExampleOptions[T]) TypeName() string { return "example" } - -type ExampleOptionsType[T int32 | uint64] struct{} - -func (*ExampleOptionsType[T]) TypeName() string { - return "example" + strconv.Itoa(int(unsafe.Sizeof(T(0)))) -} - -func (*ExampleOptionsType[T]) Compare(lhs, rhs compute.FunctionOptions) bool { - return true -} - -func (*ExampleOptionsType[T]) Copy(opts compute.FunctionOptions) compute.FunctionOptions { - o := opts.(ExampleOptions[T]) - return ExampleOptions[T]{Value: o.Value} -} - -func TestRegistryTempFunctionOptionsType(t *testing.T) { - defaultRegistry := registry - optsTypes := []compute.FunctionOptionsType{ - &ExampleOptionsType[int32]{}, - &ExampleOptionsType[uint64]{}, - } - const rounds = 3 - for i := 0; i < rounds; i++ { - registry := compute.NewChildRegistry(defaultRegistry) - for _, opttype := range optsTypes { - assert.True(t, registry.CanAddFunctionOptionsType(opttype, false)) - assert.True(t, registry.AddFunctionOptionsType(opttype, false)) - assert.False(t, registry.CanAddFunctionOptionsType(opttype, false)) - assert.False(t, registry.AddFunctionOptionsType(opttype, false)) - assert.True(t, defaultRegistry.CanAddFunctionOptionsType(opttype, false)) - - opt, ok := registry.GetFunctionOptionsType(opttype.TypeName()) - assert.True(t, ok) - assert.Same(t, opttype, opt) - - _, ok = registry.GetFunctionOptionsType("foobar") - assert.False(t, ok) - } - } -} - func TestRegistryRegisterNestedFunction(t *testing.T) { defaultRegistry := registry func1 := &mockFn{name: "f1"} func2 := &mockFn{name: "f2"} - optType1 := &ExampleOptionsType[int32]{} - optType2 := &ExampleOptionsType[uint64]{} - const rounds = 3 for i := 0; i < rounds; i++ { registry1 := compute.NewChildRegistry(defaultRegistry) assert.True(t, registry1.CanAddFunction(func1, false)) assert.True(t, registry1.AddFunction(func1, false)) - assert.True(t, registry1.CanAddFunctionOptionsType(optType1, false)) - assert.True(t, registry1.AddFunctionOptionsType(optType1, false)) for j := 0; j < rounds; j++ { registry2 := compute.NewChildRegistry(registry1) assert.False(t, registry2.CanAddFunction(func1, false)) assert.False(t, registry2.AddFunction(func1, false)) - assert.False(t, registry2.CanAddFunctionOptionsType(optType1, false)) - assert.False(t, registry2.AddFunctionOptionsType(optType1, false)) assert.True(t, registry2.CanAddFunction(func2, false)) assert.True(t, registry2.AddFunction(func2, false)) assert.False(t, registry2.CanAddFunction(func2, false)) assert.False(t, registry2.AddFunction(func2, false)) - assert.True(t, registry2.CanAddFunctionOptionsType(optType2, false)) - assert.True(t, registry2.AddFunctionOptionsType(optType2, false)) assert.True(t, defaultRegistry.CanAddFunction(func2, false)) - assert.True(t, defaultRegistry.CanAddFunctionOptionsType(optType2, false)) assert.False(t, registry2.CanAddAlias("f1", "f2")) assert.False(t, registry2.AddAlias("f1", "f2")) From 3cfacfc4249f12324c131ed92841220a0f4bd9d9 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 22 Aug 2022 11:08:29 -0400 Subject: [PATCH 3/3] add parent.mx.RLock to canAddFuncName --- go/arrow/compute/registry.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/go/arrow/compute/registry.go b/go/arrow/compute/registry.go index f33f6991e39e6..b749cd9d0e6f3 100644 --- a/go/arrow/compute/registry.go +++ b/go/arrow/compute/registry.go @@ -138,12 +138,19 @@ func (reg *funcRegistry) NumFunctions() (n int) { if reg.parent != nil { n = reg.parent.NumFunctions() } + reg.mx.RLock() + defer reg.mx.RUnlock() return n + len(reg.nameToFunction) } func (reg *funcRegistry) canAddFuncName(name string, allowOverwrite bool) bool { - if reg.parent != nil && !reg.parent.canAddFuncName(name, allowOverwrite) { - return false + if reg.parent != nil { + reg.parent.mx.RLock() + defer reg.parent.mx.RUnlock() + + if !reg.parent.canAddFuncName(name, allowOverwrite) { + return false + } } if !allowOverwrite { _, ok := reg.nameToFunction[name]