From c8e76923621b379fb7deb6dfb944011af1d980bd Mon Sep 17 00:00:00 2001
From: Alvaro Viebrantz
Date: Mon, 23 Oct 2023 09:08:39 -0700
Subject: [PATCH] feat(bigquery): expose Apache Arrow data through ArrowIterator (#8506)

We have some planned work to support Arrow data fetching on other query
APIs, so we need an interface that will support all of those query paths
and also work as a base for other Arrow projects like ADBC. This PR
therefore detaches the Storage API from the Arrow decoder and creates a
new ArrowIterator interface. The new interface is implemented by the
Storage iterator and can later be implemented for other query interfaces
that support Arrow.

Resolves #8100
---
 bigquery/arrow.go                    | 108 +++++++++++++++++++++-----
 bigquery/iterator.go                 |   3 +-
 bigquery/storage_bench_test.go       |   2 +-
 bigquery/storage_integration_test.go |  96 +++++++++++++++++++++--
 bigquery/storage_iterator.go         | 113 ++++++++++++++++-----------
 bigquery/storage_iterator_test.go    |   2 +-
 6 files changed, 250 insertions(+), 74 deletions(-)

diff --git a/bigquery/arrow.go b/bigquery/arrow.go
index c5120901b530..bd7f0c3ebf26 100644
--- a/bigquery/arrow.go
+++ b/bigquery/arrow.go
@@ -19,26 +19,92 @@ import (
 	"encoding/base64"
 	"errors"
 	"fmt"
+	"io"
 	"math/big"
 
 	"cloud.google.com/go/civil"
 	"github.com/apache/arrow/go/v12/arrow"
 	"github.com/apache/arrow/go/v12/arrow/array"
 	"github.com/apache/arrow/go/v12/arrow/ipc"
+	"github.com/apache/arrow/go/v12/arrow/memory"
+	"google.golang.org/api/iterator"
 )
 
-type arrowDecoder struct {
-	tableSchema    Schema
-	rawArrowSchema []byte
-	arrowSchema    *arrow.Schema
+// ArrowRecordBatch represents an Arrow RecordBatch with the source PartitionID.
+type ArrowRecordBatch struct {
+	reader io.Reader
+	// Serialized Arrow Record Batch.
+	Data []byte
+	// Serialized Arrow Schema.
+	Schema []byte
+	// Source partition ID. In the Storage API world, it represents the ReadStream.
+	PartitionID string
+}
+
+// Read makes ArrowRecordBatch implement io.Reader.
+func (r *ArrowRecordBatch) Read(p []byte) (int, error) {
+	if r.reader == nil {
+		buf := bytes.NewBuffer(r.Schema)
+		buf.Write(r.Data)
+		r.reader = buf
+	}
+	return r.reader.Read(p)
+}
+
+// ArrowIterator represents a way to iterate through a stream of Arrow records.
+// Experimental: this interface is experimental and may be modified or removed in future versions,
+// regardless of any other documented package stability guarantees.
+type ArrowIterator interface {
+	Next() (*ArrowRecordBatch, error)
+	Schema() Schema
+	SerializedArrowSchema() []byte
 }
 
-func newArrowDecoderFromSession(session *readSession, schema Schema) (*arrowDecoder, error) {
-	bqSession := session.bqSession
-	if bqSession == nil {
-		return nil, errors.New("read session not initialized")
+// NewArrowIteratorReader allows consuming an ArrowIterator as an io.Reader.
+// Experimental: this interface is experimental and may be modified or removed in future versions,
+// regardless of any other documented package stability guarantees.
+func NewArrowIteratorReader(it ArrowIterator) io.Reader {
+	return &arrowIteratorReader{
+		it: it,
 	}
-	arrowSerializedSchema := bqSession.GetArrowSchema().GetSerializedSchema()
+}
+
+type arrowIteratorReader struct {
+	buf *bytes.Buffer
+	it  ArrowIterator
+}
+
+// Read makes arrowIteratorReader implement io.Reader.
+func (r *arrowIteratorReader) Read(p []byte) (int, error) {
+	if r.it == nil {
+		return 0, errors.New("bigquery: nil ArrowIterator")
+	}
+	if r.buf == nil { // init with schema
+		buf := bytes.NewBuffer(r.it.SerializedArrowSchema())
+		r.buf = buf
+	}
+	n, err := r.buf.Read(p)
+	if err == io.EOF {
+		batch, err := r.it.Next()
+		if err == iterator.Done {
+			return 0, io.EOF
+		}
+		if err != nil {
+			return 0, err
+		}
+		r.buf.Write(batch.Data)
+		return r.Read(p)
+	}
+	return n, err
+}
+
+type arrowDecoder struct {
+	allocator   memory.Allocator
+	tableSchema Schema
+	arrowSchema *arrow.Schema
+}
+
+func newArrowDecoder(arrowSerializedSchema []byte, schema Schema) (*arrowDecoder, error) {
 	buf := bytes.NewBuffer(arrowSerializedSchema)
 	r, err := ipc.NewReader(buf)
 	if err != nil {
@@ -46,22 +112,24 @@ func newArrowDecoderFromSession(session *readSession, schema Schema) (*arrowDeco
 	}
 	defer r.Release()
 	p := &arrowDecoder{
-		tableSchema:    schema,
-		rawArrowSchema: arrowSerializedSchema,
-		arrowSchema:    r.Schema(),
+		tableSchema: schema,
+		arrowSchema: r.Schema(),
+		allocator:   memory.DefaultAllocator,
 	}
 	return p, nil
 }
 
-func (ap *arrowDecoder) createIPCReaderForBatch(serializedArrowRecordBatch []byte) (*ipc.Reader, error) {
-	buf := bytes.NewBuffer(ap.rawArrowSchema)
-	buf.Write(serializedArrowRecordBatch)
-	return ipc.NewReader(buf, ipc.WithSchema(ap.arrowSchema))
+func (ap *arrowDecoder) createIPCReaderForBatch(arrowRecordBatch *ArrowRecordBatch) (*ipc.Reader, error) {
+	return ipc.NewReader(
+		arrowRecordBatch,
+		ipc.WithSchema(ap.arrowSchema),
+		ipc.WithAllocator(ap.allocator),
+	)
 }
 
 // decodeArrowRecords decodes BQ ArrowRecordBatch into rows of []Value.
-func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([][]Value, error) {
-	r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
+func (ap *arrowDecoder) decodeArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([][]Value, error) {
+	r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
 	if err != nil {
 		return nil, err
 	}
@@ -79,8 +147,8 @@ func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([
 }
 
 // decodeRetainedArrowRecords decodes BQ ArrowRecordBatch into a list of retained arrow.Record.
-func (ap *arrowDecoder) decodeRetainedArrowRecords(serializedArrowRecordBatch []byte) ([]arrow.Record, error) {
-	r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
+func (ap *arrowDecoder) decodeRetainedArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([]arrow.Record, error) {
+	r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
 	if err != nil {
 		return nil, err
 	}
diff --git a/bigquery/iterator.go b/bigquery/iterator.go
index b5823e7f4a73..37b12dea84d4 100644
--- a/bigquery/iterator.go
+++ b/bigquery/iterator.go
@@ -44,7 +44,8 @@ type RowIterator struct {
 	ctx context.Context
 	src *rowSource
 
-	arrowIterator *arrowIterator
+	arrowIterator ArrowIterator
+	arrowDecoder  *arrowDecoder
 
 	pageInfo *iterator.PageInfo
 	nextFunc func() error
diff --git a/bigquery/storage_bench_test.go b/bigquery/storage_bench_test.go
index 690a561f02fc..53c9feea5bea 100644
--- a/bigquery/storage_bench_test.go
+++ b/bigquery/storage_bench_test.go
@@ -74,7 +74,7 @@ func BenchmarkIntegration_StorageReadQuery(b *testing.B) {
 		}
 	}
 	b.ReportMetric(float64(it.TotalRows), "rows")
-	bqSession := it.arrowIterator.session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
 	b.ReportMetric(float64(len(bqSession.Streams)), "parallel_streams")
 	b.ReportMetric(float64(maxStreamCount), "max_streams")
 }
diff --git a/bigquery/storage_integration_test.go b/bigquery/storage_integration_test.go
index 88924ec758a9..ed4114f56d0a 100644
--- a/bigquery/storage_integration_test.go
+++ b/bigquery/storage_integration_test.go
@@ -22,6 +22,11 @@ import (
 	"time"
 
 	"cloud.google.com/go/internal/testutil"
+	"github.com/apache/arrow/go/v12/arrow"
+	"github.com/apache/arrow/go/v12/arrow/array"
+	"github.com/apache/arrow/go/v12/arrow/ipc"
+	"github.com/apache/arrow/go/v12/arrow/math"
+	"github.com/apache/arrow/go/v12/arrow/memory"
 	"github.com/google/go-cmp/cmp"
 	"google.golang.org/api/iterator"
 )
@@ -250,11 +255,12 @@ func TestIntegration_StorageReadQueryOrdering(t *testing.T) {
 		}
 		total++ // as we read the first value separately
 
-		bqSession := it.arrowIterator.session.bqSession
+		session := it.arrowIterator.(*storageArrowIterator).session
+		bqSession := session.bqSession
 		if len(bqSession.Streams) == 0 {
 			t.Fatalf("%s: expected to use at least one stream but found %d", tc.name, len(bqSession.Streams))
 		}
-		streamSettings := it.arrowIterator.session.settings.maxStreamCount
+		streamSettings := session.settings.maxStreamCount
 		if tc.maxExpectedStreams > 0 {
 			if streamSettings > tc.maxExpectedStreams {
 				t.Fatalf("%s: expected stream settings to be at most %d streams but found %d", tc.name, tc.maxExpectedStreams, streamSettings)
@@ -317,7 +323,7 @@ func TestIntegration_StorageReadQueryStruct(t *testing.T) {
 		total++
 	}
 
-	bqSession := it.arrowIterator.session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
 	if len(bqSession.Streams) == 0 {
 		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
 	}
@@ -366,7 +372,7 @@ func TestIntegration_StorageReadQueryMorePages(t *testing.T) {
 	}
 	total++ // as we read the first value separately
 
-	bqSession := it.arrowIterator.session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
 	if len(bqSession.Streams) == 0 {
 		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
 	}
@@ -418,11 +424,91 @@ func TestIntegration_StorageReadCancel(t *testing.T) {
 	}
 	// resources are cleaned asynchronously
 	time.Sleep(time.Second)
-	if !it.arrowIterator.isDone() {
+	arrowIt := it.arrowIterator.(*storageArrowIterator)
+	if !arrowIt.isDone() {
 		t.Fatal("expected stream to be done")
 	}
 }
 
+func TestIntegration_StorageReadArrow(t *testing.T) {
+	if client == nil {
+		t.Skip("Integration tests skipped")
+	}
+	ctx := context.Background()
+	table := "`bigquery-public-data.usa_names.usa_1910_current`"
+	sql := fmt.Sprintf(`SELECT name, number, state FROM %s where state = "CA"`, table)
+
+	q := storageOptimizedClient.Query(sql)
+	job, err := q.Run(ctx) // force usage of Storage API by skipping fast paths
+	if err != nil {
+		t.Fatal(err)
+	}
+	it, err := job.Read(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	checkedAllocator := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	it.arrowDecoder.allocator = checkedAllocator
+	defer checkedAllocator.AssertSize(t, 0)
+
+	arrowIt, err := it.ArrowIterator()
+	if err != nil {
+		t.Fatalf("expected iterator to be accelerated: %v", err)
+	}
+	arrowItReader := NewArrowIteratorReader(arrowIt)
+
+	records := []arrow.Record{}
+	r, err := ipc.NewReader(arrowItReader, ipc.WithAllocator(checkedAllocator))
+	if err != nil {
+		t.Fatal(err)
+	}
+	numrec := 0
+	for r.Next() {
+		rec := r.Record()
+		rec.Retain()
+		defer rec.Release()
+		records = append(records, rec)
+		numrec += int(rec.NumRows())
+	}
+	r.Release()
+
+	arrowSchema := r.Schema()
+	arrowTable := array.NewTableFromRecords(arrowSchema, records)
+	defer arrowTable.Release()
+	if arrowTable.NumRows() != int64(it.TotalRows) {
+		t.Fatalf("should have a table with %d rows, but found %d", it.TotalRows, arrowTable.NumRows())
+	}
+	if arrowTable.NumCols() != 3 {
+		t.Fatalf("should have a table with 3 columns, but found %d", arrowTable.NumCols())
+	}
+
+	sumSQL := fmt.Sprintf(`SELECT sum(number) as total FROM %s where state = "CA"`, table)
+	sumQuery := client.Query(sumSQL)
+	sumIt, err := sumQuery.Read(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	sumValues := []Value{}
+	err = sumIt.Next(&sumValues)
+	if err != nil {
+		t.Fatal(err)
+	}
+	totalFromSQL := sumValues[0].(int64)
+
+	tr := array.NewTableReader(arrowTable, arrowTable.NumRows())
+	defer tr.Release()
+	var totalFromArrow int64
+	for tr.Next() {
+		rec := tr.Record()
+		vec := rec.Column(1).(*array.Int64)
+		totalFromArrow += math.Int64.Sum(vec)
+	}
+	if totalFromArrow != totalFromSQL {
+		t.Fatalf("expected total to be %d, but with arrow we got %d", totalFromSQL, totalFromArrow)
+	}
+}
+
 func countIteratorRows(it *RowIterator) (total uint64, err error) {
 	for {
 		var dst []Value
diff --git a/bigquery/storage_iterator.go b/bigquery/storage_iterator.go
index ae5ae2a66d17..96e12f869a77 100644
--- a/bigquery/storage_iterator.go
+++ b/bigquery/storage_iterator.go
@@ -32,20 +32,21 @@ import (
 	"google.golang.org/grpc/status"
 )
 
-// arrowIterator is a raw interface for getting data from Storage Read API
-type arrowIterator struct {
-	done uint32 // atomic flag
-	errs chan error
-	ctx  context.Context
+// storageArrowIterator is a raw iterator for getting data from the Storage Read API
+type storageArrowIterator struct {
+	done        uint32 // atomic flag
+	initialized bool
+	errs        chan error
+	ctx         context.Context
 
-	schema  Schema
-	decoder *arrowDecoder
-	records chan arrowRecordBatch
+	schema    Schema
+	rawSchema []byte
+	records   chan *ArrowRecordBatch
 
 	session *readSession
 }
 
-type arrowRecordBatch []byte
+var _ ArrowIterator = &storageArrowIterator{}
 
 func newStorageRowIteratorFromTable(ctx context.Context, table *Table, ordered bool) (*RowIterator, error) {
 	md, err := table.Metadata(ctx)
@@ -56,11 +57,19 @@ func newStorageRowIteratorFromTable(ctx context.Context, table *Table, ordered b
 	if err != nil {
 		return nil, err
 	}
-	it, err := newStorageRowIterator(rs)
+	it, err := newStorageRowIterator(rs, md.Schema)
 	if err != nil {
 		return nil, err
 	}
-	it.arrowIterator.schema = md.Schema
+	if rs.bqSession == nil {
+		return nil, errors.New("read session not initialized")
+	}
+	arrowSerializedSchema := rs.bqSession.GetArrowSchema().GetSerializedSchema()
+	dec, err := newArrowDecoder(arrowSerializedSchema, md.Schema)
+	if err != nil {
+		return nil, err
+	}
+	it.arrowDecoder = dec
 	it.Schema = md.Schema
 	return it, nil
 }
@@ -112,11 +121,12 @@ func resolveLastChildSelectJob(ctx context.Context, job *Job) (*Job, error) {
 	return childJobs[0], nil
 }
 
-func newRawStorageRowIterator(rs *readSession) (*arrowIterator, error) {
-	arrowIt := &arrowIterator{
+func newRawStorageRowIterator(rs *readSession, schema Schema) (*storageArrowIterator, error) {
+	arrowIt := &storageArrowIterator{
 		ctx:     rs.ctx,
 		session: rs,
-		records: make(chan arrowRecordBatch, rs.settings.maxWorkerCount+1),
+		schema:  schema,
+		records: make(chan *ArrowRecordBatch, rs.settings.maxWorkerCount+1),
 		errs:    make(chan error, rs.settings.maxWorkerCount+1),
 	}
 	if rs.bqSession == nil {
@@ -125,11 +135,12 @@ func newRawStorageRowIterator(rs *readSession) (*arrowIterator, error) {
 			return nil, err
 		}
 	}
+	arrowIt.rawSchema = rs.bqSession.GetArrowSchema().GetSerializedSchema()
 	return arrowIt, nil
 }
 
-func newStorageRowIterator(rs *readSession) (*RowIterator, error) {
-	arrowIt, err := newRawStorageRowIterator(rs)
+func newStorageRowIterator(rs *readSession, schema Schema) (*RowIterator, error) {
+	arrowIt, err := newRawStorageRowIterator(rs, schema)
 	if err != nil {
 		return nil, err
 	}
@@ -153,8 +164,7 @@ func nextFuncForStorageIterator(it *RowIterator) func() error {
 		if len(it.rows) > 0 {
 			return nil
 		}
-		arrowIt := it.arrowIterator
-		record, err := arrowIt.next()
+		record, err := it.arrowIterator.Next()
 		if err == iterator.Done {
 			if len(it.rows) == 0 {
 				return iterator.Done
@@ -165,9 +175,9 @@ func nextFuncForStorageIterator(it *RowIterator) func() error {
 			return err
 		}
 		if it.Schema == nil {
-			it.Schema = it.arrowIterator.schema
+			it.Schema = it.arrowIterator.Schema()
 		}
-		rows, err := arrowIt.decoder.decodeArrowRecords(record)
+		rows, err := it.arrowDecoder.decodeArrowRecords(record)
 		if err != nil {
 			return err
 		}
@@ -176,8 +186,8 @@ func nextFuncForStorageIterator(it *RowIterator) func() error {
 	}
 }
 
-func (it *arrowIterator) init() error {
-	if it.decoder != nil { // Already initialized
+func (it *storageArrowIterator) init() error {
+	if it.initialized {
 		return nil
 	}
 
@@ -191,20 +201,6 @@ func (it *arrowIterator) init() error {
 		return iterator.Done
 	}
 
-	if it.schema == nil {
-		meta, err := it.session.table.Metadata(it.ctx)
-		if err != nil {
-			return err
-		}
-		it.schema = meta.Schema
-	}
-
-	decoder, err := newArrowDecoderFromSession(it.session, it.schema)
-	if err != nil {
-		return err
-	}
-	it.decoder = decoder
-
 	wg := sync.WaitGroup{}
 	wg.Add(len(streams))
 	sem := semaphore.NewWeighted(int64(it.session.settings.maxWorkerCount))
@@ -229,18 +225,19 @@ func (it *arrowIterator) init() error {
 		}(readStream.Name)
 	}
 	}()
+	it.initialized = true
 	return nil
 }
 
-func (it *arrowIterator) markDone() {
+func (it *storageArrowIterator) markDone() {
 	atomic.StoreUint32(&it.done, 1)
 }
 
-func (it *arrowIterator) isDone() bool {
+func (it *storageArrowIterator) isDone() bool {
 	return atomic.LoadUint32(&it.done) != 0
 }
 
-func (it *arrowIterator) processStream(readStream string) {
+func (it *storageArrowIterator) processStream(readStream string) {
 	bo := gax.Backoff{}
 	var offset int64
 	for {
@@ -300,7 +297,7 @@ func retryReadRows(bo gax.Backoff, err error) (time.Duration, bool) {
 	return bo.Pause(), false
 }
 
-func (it *arrowIterator) consumeRowStream(readStream string, rowStream storagepb.BigQueryRead_ReadRowsClient, offset int64) (int64, error) {
+func (it *storageArrowIterator) consumeRowStream(readStream string, rowStream storagepb.BigQueryRead_ReadRowsClient, offset int64) (int64, error) {
 	for {
 		r, err := rowStream.Recv()
 		if err != nil {
@@ -311,8 +308,12 @@ func (it *arrowIterator) consumeRowStream(readStream string, rowStream storagepb
 		}
 		if r.RowCount > 0 {
 			offset += r.RowCount
-			arrowRecordBatch := r.GetArrowRecordBatch()
-			it.records <- arrowRecordBatch.SerializedRecordBatch
+			recordBatch := r.GetArrowRecordBatch()
+			it.records <- &ArrowRecordBatch{
+				PartitionID: readStream,
+				Schema:      it.rawSchema,
+				Data:        recordBatch.SerializedRecordBatch,
+			}
 		}
 	}
 }
@@ -320,7 +321,7 @@ func (it *arrowIterator) consumeRowStream(readStream string, rowStream storagepb
-// next return the next batch of rows as an arrow.Record.
-// Accessing Arrow Records directly has the drawnback of having to deal
-// with memory management.
-func (it *arrowIterator) next() (arrowRecordBatch, error) {
+// Next returns the next batch of rows as an ArrowRecordBatch.
+// Accessing Arrow records directly has the drawback of having to deal
+// with memory management.
+func (it *storageArrowIterator) Next() (*ArrowRecordBatch, error) {
 	if err := it.init(); err != nil {
 		return nil, err
 	}
@@ -343,8 +344,28 @@ func (it *arrowIterator) next() (arrowRecordBatch, error) {
 	}
 }
 
+func (it *storageArrowIterator) SerializedArrowSchema() []byte {
+	return it.rawSchema
+}
+
+func (it *storageArrowIterator) Schema() Schema {
+	return it.schema
+}
+
 // IsAccelerated check if the current RowIterator is
 // being accelerated by Storage API.
 func (it *RowIterator) IsAccelerated() bool {
 	return it.arrowIterator != nil
}
+
+// ArrowIterator gives access to the raw Arrow Record Batch stream to be consumed directly.
+// Experimental: this interface is experimental and may be modified or removed in future versions,
+// regardless of any other documented package stability guarantees.
+// Don't mix RowIterator.Next and ArrowIterator.Next calls on the same iterator.
+func (it *RowIterator) ArrowIterator() (ArrowIterator, error) {
+	if !it.IsAccelerated() {
+		// TODO: can we convert a plain JSON-API-based RowIterator to an Arrow stream?
+		return nil, errors.New("bigquery: storage read API must be enabled")
+	}
+	return it.arrowIterator, nil
+}
diff --git a/bigquery/storage_iterator_test.go b/bigquery/storage_iterator_test.go
index 75e0d87973a1..10471184f544 100644
--- a/bigquery/storage_iterator_test.go
+++ b/bigquery/storage_iterator_test.go
@@ -110,7 +110,7 @@ func TestStorageIteratorRetry(t *testing.T) {
 			settings:     defaultReadClientSettings(),
 			readRowsFunc: readRowsFunc,
 			bqSession:    &storagepb.ReadSession{},
-		})
+		}, Schema{})
 		if err != nil {
 			t.Fatalf("case %s: newRawStorageRowIterator: %v", tc.desc, err)
 		}
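
Usage sketch (a reviewer note, not part of the patch): the flow below shows
how the new surface is meant to fit together, mirroring
TestIntegration_StorageReadArrow above. The project ID is a placeholder, and
it assumes the caller enables the Storage Read API via
EnableStorageReadClient, since RowIterator.ArrowIterator returns an error
whenever that API did not serve the query.

package main

import (
	"context"
	"fmt"
	"io"
	"log"

	"cloud.google.com/go/bigquery"
	"github.com/apache/arrow/go/v12/arrow/ipc"
)

func main() {
	ctx := context.Background()
	// "my-project" is a placeholder project ID.
	client, err := bigquery.NewClient(ctx, "my-project")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()
	// Opt in to the Storage Read API so that large results stream as Arrow.
	if err := client.EnableStorageReadClient(ctx); err != nil {
		log.Fatal(err)
	}

	q := client.Query("SELECT name, number, state FROM `bigquery-public-data.usa_names.usa_1910_current` WHERE state = \"CA\"")
	it, err := q.Read(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// Errors if the query was not served by the Storage Read API.
	arrowIt, err := it.ArrowIterator()
	if err != nil {
		log.Fatal(err)
	}

	// NewArrowIteratorReader exposes the iterator as a standard Arrow IPC
	// stream (serialized schema, then each serialized record batch), so it
	// can be handed directly to an IPC reader.
	r, err := ipc.NewReader(bigquery.NewArrowIteratorReader(arrowIt))
	if err != nil {
		log.Fatal(err)
	}
	defer r.Release()

	var rows int64
	for r.Next() {
		rows += r.Record().NumRows()
	}
	// Defensively treat io.EOF as a normal end of stream.
	if err := r.Err(); err != nil && err != io.EOF {
		log.Fatal(err)
	}
	fmt.Printf("read %d rows with %d columns\n", rows, len(r.Schema().Fields()))
}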
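
Record batches can also be pulled off the ArrowIterator one at a time: each
ArrowRecordBatch carries its source stream in PartitionID and itself
implements io.Reader (serialized schema followed by the batch body), so a
single batch can be decoded in isolation. A sketch under the same
assumptions as above, reusing the ipc, fmt, and io imports and additionally
google.golang.org/api/iterator for the Done sentinel; printBatchSizes is a
hypothetical helper, not part of the patch:

// printBatchSizes drains an ArrowIterator batch by batch instead of
// consuming it as one contiguous IPC stream.
func printBatchSizes(arrowIt bigquery.ArrowIterator) error {
	for {
		batch, err := arrowIt.Next()
		if err == iterator.Done {
			return nil
		}
		if err != nil {
			return err
		}
		// Each batch reads as schema + body, so a fresh IPC reader can
		// decode it without any shared state.
		rr, err := ipc.NewReader(batch)
		if err != nil {
			return err
		}
		for rr.Next() {
			fmt.Printf("partition %s: %d rows\n", batch.PartitionID, rr.Record().NumRows())
		}
		err = rr.Err()
		rr.Release()
		if err != nil && err != io.EOF {
			return err
		}
	}
}

This per-batch framing (schema prepended to each batch) is the same one the
decoder relies on internally in createIPCReaderForBatch.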