Skip to content

Commit

Permalink
GH-38918: [Go] Avoid schema.Fields allocations in some places (#38919)
Browse files Browse the repository at this point in the history
### Rationale for this change

Unnecessary allocations.

### What changes are included in this PR?

This PR is split into several commits. The first addresses allocations in the `dictutils` package. The second adds `NumFields` to `NestedType` so that the third commit — a purely mechanical change from `len(type.Fields())` to `type.NumFields()` that avoids allocations in these specific cases — can pass tests with no further changes.

The last commit removes some `Fields()` allocations that specifically hurt our project. Note that this is not an all-encompassing change, so this PR should probably not close the linked issue.

### Are these changes tested?

These changes are implicitly tested by the existing test suite. No functionality has been changed, and the changes should be invisible to the user.

### Are there any user-facing changes?

No.

* Addresses: #38918
* Closes: #38918

Authored-by: Alfonso Subiotto Marques <alfonso.subiotto@polarsignals.com>
Signed-off-by: Matt Topol <zotthewizard@gmail.com>
  • Loading branch information
asubiotto authored Nov 28, 2023
1 parent 143b475 commit 82be255
Show file tree
Hide file tree
Showing 30 changed files with 99 additions and 69 deletions.
2 changes: 1 addition & 1 deletion go/arrow/array/concat.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ func concat(data []arrow.ArrayData, mem memory.Allocator) (arr arrow.ArrayData,
}
out.childData = []arrow.ArrayData{children}
case *arrow.StructType:
out.childData = make([]arrow.ArrayData, len(dt.Fields()))
out.childData = make([]arrow.ArrayData, dt.NumFields())
for i := range dt.Fields() {
children := gatherChildren(data, i)
for _, c := range children {
Expand Down
12 changes: 6 additions & 6 deletions go/arrow/array/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func (rec *simpleRecord) validate() error {
return nil
}

if len(rec.arrs) != len(rec.schema.Fields()) {
if len(rec.arrs) != rec.schema.NumFields() {
return fmt.Errorf("arrow/array: number of columns/fields mismatch")
}

Expand Down Expand Up @@ -285,11 +285,11 @@ func NewRecordBuilder(mem memory.Allocator, schema *arrow.Schema) *RecordBuilder
refCount: 1,
mem: mem,
schema: schema,
fields: make([]Builder, len(schema.Fields())),
fields: make([]Builder, schema.NumFields()),
}

for i, f := range schema.Fields() {
b.fields[i] = NewBuilder(b.mem, f.Type)
for i := 0; i < schema.NumFields(); i++ {
b.fields[i] = NewBuilder(b.mem, schema.Field(i).Type)
}

return b
Expand Down Expand Up @@ -397,8 +397,8 @@ func (b *RecordBuilder) UnmarshalJSON(data []byte) error {
}
}

for i, f := range b.schema.Fields() {
if !keylist[f.Name] {
for i := 0; i < b.schema.NumFields(); i++ {
if !keylist[b.schema.Field(i).Name] {
b.fields[i].AppendNull()
}
}
Expand Down
2 changes: 1 addition & 1 deletion go/arrow/array/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func NewStructBuilder(mem memory.Allocator, dtype *arrow.StructType) *StructBuil
b := &StructBuilder{
builder: builder{refCount: 1, mem: mem},
dtype: dtype,
fields: make([]Builder, len(dtype.Fields())),
fields: make([]Builder, dtype.NumFields()),
}
for i, f := range dtype.Fields() {
b.fields[i] = NewBuilder(b.mem, f.Type)
Expand Down
8 changes: 4 additions & 4 deletions go/arrow/array/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ func NewTable(schema *arrow.Schema, cols []arrow.Column, rows int64) *simpleTabl
// - the total length of each column's array slice (ie: number of rows
// in the column) aren't the same for all columns.
func NewTableFromSlice(schema *arrow.Schema, data [][]arrow.Array) *simpleTable {
if len(data) != len(schema.Fields()) {
if len(data) != schema.NumFields() {
panic("array/table: mismatch in number of columns and data for creating a table")
}

cols := make([]arrow.Column, len(schema.Fields()))
cols := make([]arrow.Column, schema.NumFields())
for i, arrs := range data {
field := schema.Field(i)
chunked := arrow.NewChunked(field.Type, arrs)
Expand Down Expand Up @@ -177,7 +177,7 @@ func NewTableFromSlice(schema *arrow.Schema, data [][]arrow.Array) *simpleTable
// NewTableFromRecords panics if the records and schema are inconsistent.
func NewTableFromRecords(schema *arrow.Schema, recs []arrow.Record) *simpleTable {
arrs := make([]arrow.Array, len(recs))
cols := make([]arrow.Column, len(schema.Fields()))
cols := make([]arrow.Column, schema.NumFields())

defer func(cols []arrow.Column) {
for i := range cols {
Expand Down Expand Up @@ -224,7 +224,7 @@ func (tbl *simpleTable) NumCols() int64 { return int64(len(tbl.cols)
func (tbl *simpleTable) Column(i int) *arrow.Column { return &tbl.cols[i] }

func (tbl *simpleTable) validate() {
if len(tbl.cols) != len(tbl.schema.Fields()) {
if len(tbl.cols) != tbl.schema.NumFields() {
panic(errors.New("arrow/array: table schema mismatch"))
}
for i, col := range tbl.cols {
Expand Down
4 changes: 2 additions & 2 deletions go/arrow/array/union.go
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ func NewEmptySparseUnionBuilder(mem memory.Allocator) *SparseUnionBuilder {
// children and type codes. Builders will be constructed for each child
// using the fields in typ
func NewSparseUnionBuilder(mem memory.Allocator, typ *arrow.SparseUnionType) *SparseUnionBuilder {
children := make([]Builder, len(typ.Fields()))
children := make([]Builder, typ.NumFields())
for i, f := range typ.Fields() {
children[i] = NewBuilder(mem, f.Type)
defer children[i].Release()
Expand Down Expand Up @@ -1129,7 +1129,7 @@ func NewEmptyDenseUnionBuilder(mem memory.Allocator) *DenseUnionBuilder {
// children and type codes. Builders will be constructed for each child
// using the fields in typ
func NewDenseUnionBuilder(mem memory.Allocator, typ *arrow.DenseUnionType) *DenseUnionBuilder {
children := make([]Builder, 0, len(typ.Fields()))
children := make([]Builder, 0, typ.NumFields())
defer func() {
for _, child := range children {
child.Release()
Expand Down
2 changes: 1 addition & 1 deletion go/arrow/array/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ func (n *nullArrayFactory) create() *Data {
}

if nf, ok := dt.(arrow.NestedType); ok {
childData = make([]arrow.ArrayData, len(nf.Fields()))
childData = make([]arrow.ArrayData, nf.NumFields())
}

switch dt := dt.(type) {
Expand Down
2 changes: 1 addition & 1 deletion go/arrow/cdata/cdata_exports.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (exp *schemaExporter) export(field arrow.Field) {
exp.dict = new(schemaExporter)
exp.dict.export(arrow.Field{Type: dt.ValueType})
case arrow.NestedType:
exp.children = make([]schemaExporter, len(dt.Fields()))
exp.children = make([]schemaExporter, dt.NumFields())
for i, f := range dt.Fields() {
exp.children[i].export(f)
}
Expand Down
4 changes: 2 additions & 2 deletions go/arrow/compute/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ func CastStruct(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult)
opts = ctx.State.(kernels.CastState)
inType = batch.Values[0].Array.Type.(*arrow.StructType)
outType = out.Type.(*arrow.StructType)
inFieldCount = len(inType.Fields())
outFieldCount = len(outType.Fields())
inFieldCount = inType.NumFields()
outFieldCount = outType.NumFields()
)

fieldsToSelect := make([]int, outFieldCount)
Expand Down
2 changes: 1 addition & 1 deletion go/arrow/compute/exec/span.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ func FillZeroLength(dt arrow.DataType, span *ArraySpan) {
return
}

span.resizeChildren(len(nt.Fields()))
span.resizeChildren(nt.NumFields())
for i, f := range nt.Fields() {
FillZeroLength(f.Type, &span.Children[i])
}
Expand Down
2 changes: 1 addition & 1 deletion go/arrow/compute/exprs/builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func NewFieldRefFromDotPath(dotpath string, rootSchema *arrow.Schema) (expr.Refe
idx, _ := strconv.Atoi(dotpath[:subend])
switch ct := curType.(type) {
case *arrow.StructType:
if idx > len(ct.Fields()) {
if idx > ct.NumFields() {
return nil, fmt.Errorf("%w: field out of bounds in dotpath", arrow.ErrIndex)
}
curType = ct.Field(idx).Type
Expand Down
4 changes: 2 additions & 2 deletions go/arrow/compute/exprs/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Da
partialBatch := partial.(*compute.RecordDatum).Value
batchSchema := partialBatch.Schema()

out.Values = make([]compute.Datum, len(schema.Fields()))
out.Values = make([]compute.Datum, schema.NumFields())
out.Len = partialBatch.NumRows()

for i, field := range schema.Fields() {
Expand Down Expand Up @@ -99,7 +99,7 @@ func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Da
return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
case *compute.ScalarDatum:
out.Len = 1
out.Values = make([]compute.Datum, len(schema.Fields()))
out.Values = make([]compute.Datum, schema.NumFields())

s := part.Value.(*scalar.Struct)
dt := s.Type.(*arrow.StructType)
Expand Down
2 changes: 1 addition & 1 deletion go/arrow/compute/exprs/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ func ToSubstraitType(dt arrow.DataType, nullable bool, ext ExtensionIDSet) (type
Precision: dt.GetPrecision(), Scale: dt.GetScale()}, nil
case arrow.STRUCT:
dt := dt.(*arrow.StructType)
fields := make([]types.Type, len(dt.Fields()))
fields := make([]types.Type, dt.NumFields())
var err error
for i, f := range dt.Fields() {
fields[i], err = ToSubstraitType(f.Type, f.Nullable, ext)
Expand Down
4 changes: 2 additions & 2 deletions go/arrow/compute/fieldref_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestFieldPathBasics(t *testing.T) {
assert.Nil(t, f)
assert.ErrorIs(t, err, compute.ErrEmpty)

f, err = compute.FieldPath{len(s.Fields()) * 2}.Get(s)
f, err = compute.FieldPath{s.NumFields() * 2}.Get(s)
assert.Nil(t, f)
assert.ErrorIs(t, err, compute.ErrIndexRange)
}
Expand All @@ -63,7 +63,7 @@ func TestFieldRefBasics(t *testing.T) {
}

// out of range index results in failure to match
assert.Empty(t, compute.FieldRefIndex(len(s.Fields())*2).FindAll(s.Fields()))
assert.Empty(t, compute.FieldRefIndex(s.NumFields()*2).FindAll(s.Fields()))

// lookup by name returns the indices of both matching fields
assert.Equal(t, []compute.FieldPath{{0}, {2}}, compute.FieldRefName("alpha").FindAll(s.Fields()))
Expand Down
2 changes: 2 additions & 0 deletions go/arrow/datatype_encoded.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ func (t *RunEndEncodedType) Fields() []Field {
}
}

// NumFields reports the number of child fields of a run-end-encoded type,
// which is always 2.
func (t *RunEndEncodedType) NumFields() int {
	return 2
}

func (*RunEndEncodedType) ValidRunEndsType(dt DataType) bool {
switch dt.ID() {
case INT16, INT32, INT64:
Expand Down
7 changes: 7 additions & 0 deletions go/arrow/datatype_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ func (e *ExtensionBase) Fields() []Field {
return nil
}

// NumFields returns the number of fields of the storage type when that
// storage type is a NestedType, and 0 otherwise.
func (e *ExtensionBase) NumFields() int {
	nested, isNested := e.Storage.(NestedType)
	if !isNested {
		return 0
	}
	return nested.NumFields()
}

func (e *ExtensionBase) Layout() DataTypeLayout { return e.Storage.Layout() }

// this no-op exists to ensure that this type must be embedded in any user-defined extension type.
Expand Down
16 changes: 16 additions & 0 deletions go/arrow/datatype_nested.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type (
// Fields method provides a copy of NestedType fields
// (so it can be safely mutated and will not result in updating the NestedType).
Fields() []Field
// NumFields provides the number of fields without allocating.
NumFields() int
}

ListLikeType interface {
Expand Down Expand Up @@ -109,6 +111,8 @@ func (t *ListType) ElemField() Field {

func (t *ListType) Fields() []Field { return []Field{t.ElemField()} }

// NumFields reports the number of child fields of a list type, which is
// always 1 (the element field).
func (t *ListType) NumFields() int {
	return 1
}

func (*ListType) Layout() DataTypeLayout {
return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(Int32SizeBytes)}}
}
Expand Down Expand Up @@ -242,6 +246,8 @@ func (t *FixedSizeListType) Fingerprint() string {

func (t *FixedSizeListType) Fields() []Field { return []Field{t.ElemField()} }

// NumFields reports the number of child fields of a fixed-size list type,
// which is always 1 (the element field).
func (t *FixedSizeListType) NumFields() int {
	return 1
}

func (*FixedSizeListType) Layout() DataTypeLayout {
return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap()}}
}
Expand Down Expand Up @@ -308,6 +314,8 @@ func (t *ListViewType) ElemField() Field {

func (t *ListViewType) Fields() []Field { return []Field{t.ElemField()} }

// NumFields reports the number of child fields of a list-view type, which
// is always 1 (the element field).
func (t *ListViewType) NumFields() int {
	return 1
}

func (*ListViewType) Layout() DataTypeLayout {
return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(Int32SizeBytes), SpecFixedWidth(Int32SizeBytes)}}
}
Expand Down Expand Up @@ -376,6 +384,8 @@ func (t *LargeListViewType) ElemField() Field {

func (t *LargeListViewType) Fields() []Field { return []Field{t.ElemField()} }

// NumFields reports the number of child fields of a large list-view type,
// which is always 1 (the element field).
func (t *LargeListViewType) NumFields() int {
	return 1
}

func (*LargeListViewType) Layout() DataTypeLayout {
return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(Int64SizeBytes), SpecFixedWidth(Int64SizeBytes)}}
}
Expand Down Expand Up @@ -447,6 +457,8 @@ func (t *StructType) Fields() []Field {
return fields
}

// NumFields reports how many fields the struct type holds. It reads the
// length of the internal field slice directly, so it does not allocate.
func (t *StructType) NumFields() int {
	return len(t.fields)
}

func (t *StructType) Field(i int) Field { return t.fields[i] }

// FieldByName gets the field with the given name.
Expand Down Expand Up @@ -598,6 +610,8 @@ func (t *MapType) Fingerprint() string {

func (t *MapType) Fields() []Field { return []Field{t.ElemField()} }

// NumFields reports the number of child fields of a map type, which is
// always 1 (the element field).
func (t *MapType) NumFields() int {
	return 1
}

func (t *MapType) Layout() DataTypeLayout {
return t.value.Layout()
}
Expand Down Expand Up @@ -690,6 +704,8 @@ func (t *unionType) Fields() []Field {
return fields
}

// NumFields reports how many child fields the union type has. It reads the
// length of the internal children slice directly, so it does not allocate.
func (t *unionType) NumFields() int {
	return len(t.children)
}

func (t *unionType) TypeCodes() []UnionTypeCode { return t.typeCodes }
func (t *unionType) ChildIDs() []int { return t.childIDs[:] }

Expand Down
2 changes: 1 addition & 1 deletion go/arrow/datatype_nested_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func TestStructOf(t *testing.T) {
t.Fatalf("invalid name. got=%q, want=%q", got, want)
}

if got, want := len(got.Fields()), len(tc.fields); got != want {
if got, want := got.NumFields(), len(tc.fields); got != want {
t.Fatalf("invalid number of fields. got=%d, want=%d", got, want)
}

Expand Down
2 changes: 1 addition & 1 deletion go/arrow/flight/flightsql/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (s *Stmt) NumInput() int {
// If NumInput returns >= 0, the sql package will sanity check argument
// counts from callers and return errors to the caller before the
// statement's Exec or Query methods are called.
return len(schema.Fields())
return schema.NumFields()
}

// Exec executes a query that doesn't return rows, such
Expand Down
2 changes: 1 addition & 1 deletion go/arrow/flight/flightsql/example/sql_batch_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ type SqlBatchReader struct {
}

func NewSqlBatchReaderWithSchema(mem memory.Allocator, schema *arrow.Schema, rows *sql.Rows) (*SqlBatchReader, error) {
rowdest := make([]interface{}, len(schema.Fields()))
rowdest := make([]interface{}, schema.NumFields())
for i, f := range schema.Fields() {
switch f.Type.ID() {
case arrow.DENSE_UNION, arrow.SPARSE_UNION:
Expand Down
8 changes: 4 additions & 4 deletions go/arrow/internal/arrjson/arrjson.go
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr

nulls := arr.Count - bitutil.CountSetBits(bitmap.Bytes(), 0, arr.Count)

fields := make([]arrow.ArrayData, len(dt.Fields()))
fields := make([]arrow.ArrayData, dt.NumFields())
for i := range fields {
child := arrayFromJSON(mem, dt.Field(i).Type, arr.Children[i])
defer child.Release()
Expand Down Expand Up @@ -1328,7 +1328,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
return array.NewData(dt, arr.Count, []*memory.Buffer{nil}, []arrow.ArrayData{runEnds, values}, 0, 0)

case arrow.UnionType:
fields := make([]arrow.ArrayData, len(dt.Fields()))
fields := make([]arrow.ArrayData, dt.NumFields())
for i, f := range dt.Fields() {
child := arrayFromJSON(mem, f.Type, arr.Children[i])
defer child.Release()
Expand Down Expand Up @@ -1620,7 +1620,7 @@ func arrayToJSON(field arrow.Field, arr arrow.Array) Array {
Name: field.Name,
Count: arr.Len(),
Valids: validsToJSON(arr),
Children: make([]Array, len(dt.Fields())),
Children: make([]Array, dt.NumFields()),
}
for i := range o.Children {
o.Children[i] = arrayToJSON(dt.Field(i), arr.Field(i))
Expand Down Expand Up @@ -1741,7 +1741,7 @@ func arrayToJSON(field arrow.Field, arr arrow.Array) Array {
Count: arr.Len(),
Valids: validsToJSON(arr),
TypeID: arr.RawTypeCodes(),
Children: make([]Array, len(dt.Fields())),
Children: make([]Array, dt.NumFields()),
}
if dt.Mode() == arrow.DenseMode {
o.Offset = arr.(*array.DenseUnion).RawValueOffsets()
Expand Down
11 changes: 8 additions & 3 deletions go/arrow/internal/dictutils/dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (d *Mapper) InsertPath(pos FieldPos) {
d.hasher.Reset()
}

func (d *Mapper) ImportField(pos FieldPos, field *arrow.Field) {
func (d *Mapper) ImportField(pos FieldPos, field arrow.Field) {
dt := field.Type
if dt.ID() == arrow.EXTENSION {
dt = dt.(arrow.ExtensionType).StorageType()
Expand All @@ -126,13 +126,18 @@ func (d *Mapper) ImportField(pos FieldPos, field *arrow.Field) {

func (d *Mapper) ImportFields(pos FieldPos, fields []arrow.Field) {
for i := range fields {
d.ImportField(pos.Child(int32(i)), &fields[i])
d.ImportField(pos.Child(int32(i)), fields[i])
}
}

func (d *Mapper) ImportSchema(schema *arrow.Schema) {
d.pathToID = make(map[uint64]int64)
d.ImportFields(NewFieldPos(), schema.Fields())
// This code path intentionally avoids calling ImportFields with
// schema.Fields to avoid allocations.
pos := NewFieldPos()
for i := 0; i < schema.NumFields(); i++ {
d.ImportField(pos.Child(int32(i)), schema.Field(i))
}
}

func hasUnresolvedNestedDict(data arrow.ArrayData) bool {
Expand Down
Loading

0 comments on commit 82be255

Please sign in to comment.