diff --git a/bigquery/arrow.go b/bigquery/arrow.go
index c5120901b530..bd7f0c3ebf26 100644
--- a/bigquery/arrow.go
+++ b/bigquery/arrow.go
@@ -19,26 +19,89 @@ import (
 	"encoding/base64"
 	"errors"
 	"fmt"
+	"io"
 	"math/big"
 
 	"cloud.google.com/go/civil"
 	"github.com/apache/arrow/go/v12/arrow"
 	"github.com/apache/arrow/go/v12/arrow/array"
 	"github.com/apache/arrow/go/v12/arrow/ipc"
+	"github.com/apache/arrow/go/v12/arrow/memory"
+	"google.golang.org/api/iterator"
 )
 
-type arrowDecoder struct {
-	tableSchema    Schema
-	rawArrowSchema []byte
-	arrowSchema    *arrow.Schema
+// ArrowRecordBatch represents an Arrow RecordBatch with the source PartitionID.
+type ArrowRecordBatch struct {
+	reader io.Reader
+	// Serialized Arrow Record Batch.
+	Data []byte
+	// Serialized Arrow Schema.
+	Schema []byte
+	// Source partition ID. In the Storage API world, it represents the ReadStream.
+	PartitionID string
+}
+
+// Read makes ArrowRecordBatch implement io.Reader.
+func (r *ArrowRecordBatch) Read(p []byte) (int, error) {
+	if r.reader == nil {
+		buf := bytes.NewBuffer(r.Schema)
+		buf.Write(r.Data)
+		r.reader = buf
+	}
+	return r.reader.Read(p)
+}
+
+// ArrowIterator represents a way to iterate through a stream of Arrow records.
+// Experimental: this interface is experimental and may be modified or removed in future versions,
+// regardless of any other documented package stability guarantees.
+type ArrowIterator interface {
+	Next() (*ArrowRecordBatch, error)
+	Schema() Schema
+	SerializedArrowSchema() []byte
 }
 
-func newArrowDecoderFromSession(session *readSession, schema Schema) (*arrowDecoder, error) {
-	bqSession := session.bqSession
-	if bqSession == nil {
-		return nil, errors.New("read session not initialized")
+// NewArrowIteratorReader allows an ArrowIterator to be consumed as an io.Reader.
+// Experimental: this interface is experimental and may be modified or removed in future versions,
+// regardless of any other documented package stability guarantees.
+func NewArrowIteratorReader(it ArrowIterator) io.Reader {
+	return &arrowIteratorReader{
+		it: it,
 	}
-	arrowSerializedSchema := bqSession.GetArrowSchema().GetSerializedSchema()
+}
+
+type arrowIteratorReader struct {
+	buf *bytes.Buffer
+	it  ArrowIterator
+}
+
+// Read makes arrowIteratorReader implement io.Reader.
+func (r *arrowIteratorReader) Read(p []byte) (int, error) {
+	if r.it == nil {
+		return 0, errors.New("bigquery: nil ArrowIterator")
+	}
+	if r.buf == nil { // init with schema
+		buf := bytes.NewBuffer(r.it.SerializedArrowSchema())
+		r.buf = buf
+	}
+	n, err := r.buf.Read(p)
+	if err == io.EOF {
+		batch, err := r.it.Next()
+		if err == iterator.Done {
+			return 0, io.EOF
+		}
+		if err != nil {
+			return 0, err
+		}
+		r.buf.Write(batch.Data)
+		return r.Read(p)
+	}
+	return n, err
+}
+
+type arrowDecoder struct {
+	allocator   memory.Allocator
+	tableSchema Schema
+	arrowSchema *arrow.Schema
+}
+
+func newArrowDecoder(arrowSerializedSchema []byte, schema Schema) (*arrowDecoder, error) {
 	buf := bytes.NewBuffer(arrowSerializedSchema)
 	r, err := ipc.NewReader(buf)
 	if err != nil {
@@ -46,22 +109,24 @@ func newArrowDecoderFromSession(session *readSession, schema Schema) (*arrowDeco
 	}
 	defer r.Release()
 	p := &arrowDecoder{
-		tableSchema:    schema,
-		rawArrowSchema: arrowSerializedSchema,
-		arrowSchema:    r.Schema(),
+		tableSchema: schema,
+		arrowSchema: r.Schema(),
+		allocator:   memory.DefaultAllocator,
 	}
 	return p, nil
 }
 
-func (ap *arrowDecoder) createIPCReaderForBatch(serializedArrowRecordBatch []byte) (*ipc.Reader, error) {
-	buf := bytes.NewBuffer(ap.rawArrowSchema)
-	buf.Write(serializedArrowRecordBatch)
-	return ipc.NewReader(buf, ipc.WithSchema(ap.arrowSchema))
+func (ap *arrowDecoder) createIPCReaderForBatch(arrowRecordBatch *ArrowRecordBatch) (*ipc.Reader, error) {
+	return ipc.NewReader(
+		arrowRecordBatch,
+		ipc.WithSchema(ap.arrowSchema),
+		ipc.WithAllocator(ap.allocator),
+	)
 }
 
 // decodeArrowRecords decodes BQ ArrowRecordBatch into rows of []Value.
-func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([][]Value, error) {
-	r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
+func (ap *arrowDecoder) decodeArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([][]Value, error) {
+	r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
 	if err != nil {
 		return nil, err
 	}
@@ -79,8 +144,8 @@ func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([
 	}
 
 // decodeRetainedArrowRecords decodes BQ ArrowRecordBatch into a list of retained arrow.Record.
-func (ap *arrowDecoder) decodeRetainedArrowRecords(serializedArrowRecordBatch []byte) ([]arrow.Record, error) {
-	r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
+func (ap *arrowDecoder) decodeRetainedArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([]arrow.Record, error) {
+	r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
 	if err != nil {
 		return nil, err
 	}
diff --git a/bigquery/iterator.go b/bigquery/iterator.go
index b5823e7f4a73..37b12dea84d4 100644
--- a/bigquery/iterator.go
+++ b/bigquery/iterator.go
@@ -44,7 +44,8 @@ type RowIterator struct {
 	ctx context.Context
 	src *rowSource
 
-	arrowIterator *arrowIterator
+	arrowIterator ArrowIterator
+	arrowDecoder  *arrowDecoder
 
 	pageInfo *iterator.PageInfo
 	nextFunc func() error
diff --git a/bigquery/storage_bench_test.go b/bigquery/storage_bench_test.go
index 690a561f02fc..53c9feea5bea 100644
--- a/bigquery/storage_bench_test.go
+++ b/bigquery/storage_bench_test.go
@@ -74,7 +74,7 @@ func BenchmarkIntegration_StorageReadQuery(b *testing.B) {
 		}
 	}
 	b.ReportMetric(float64(it.TotalRows), "rows")
-	bqSession := it.arrowIterator.session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
 	b.ReportMetric(float64(len(bqSession.Streams)), "parallel_streams")
 	b.ReportMetric(float64(maxStreamCount), "max_streams")
 }
diff --git a/bigquery/storage_integration_test.go b/bigquery/storage_integration_test.go
index 88924ec758a9..ed4114f56d0a 100644
--- a/bigquery/storage_integration_test.go
+++ b/bigquery/storage_integration_test.go
@@ -22,6 +22,11 @@ import (
 	"time"
 
 	"cloud.google.com/go/internal/testutil"
+	"github.com/apache/arrow/go/v12/arrow"
+	"github.com/apache/arrow/go/v12/arrow/array"
+	"github.com/apache/arrow/go/v12/arrow/ipc"
+	"github.com/apache/arrow/go/v12/arrow/math"
+	"github.com/apache/arrow/go/v12/arrow/memory"
 	"github.com/google/go-cmp/cmp"
 	"google.golang.org/api/iterator"
 )
@@ -250,11 +255,12 @@ func TestIntegration_StorageReadQueryOrdering(t *testing.T) {
 			}
 			total++ // as we read the first value separately
 
-			bqSession := it.arrowIterator.session.bqSession
+			session := it.arrowIterator.(*storageArrowIterator).session
+			bqSession := session.bqSession
 			if len(bqSession.Streams) == 0 {
 				t.Fatalf("%s: expected to use at least one stream but found %d", tc.name, len(bqSession.Streams))
 			}
-			streamSettings := it.arrowIterator.session.settings.maxStreamCount
+			streamSettings := session.settings.maxStreamCount
 			if tc.maxExpectedStreams > 0 {
 				if streamSettings > tc.maxExpectedStreams {
 					t.Fatalf("%s: expected stream settings to be at most %d streams but found %d", tc.name, tc.maxExpectedStreams, streamSettings)
@@ -317,7 +323,7 @@ func TestIntegration_StorageReadQueryStruct(t *testing.T) {
 		total++
 	}
 
-	bqSession := it.arrowIterator.session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
 	if len(bqSession.Streams) == 0 {
 		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
 	}
@@ -366,7 +372,7 @@ func TestIntegration_StorageReadQueryMorePages(t *testing.T) {
 		}
 		total++ // as we read the first value separately
 
-	bqSession := it.arrowIterator.session.bqSession
+	bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
 	if len(bqSession.Streams) == 0 {
 		t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
 	}
@@ -418,11 +424,88 @@ func TestIntegration_StorageReadCancel(t *testing.T) {
 	}
 	// resources are cleaned asynchronously
 	time.Sleep(time.Second)
-	if !it.arrowIterator.isDone() {
+	arrowIt := it.arrowIterator.(*storageArrowIterator)
+	if !arrowIt.isDone() {
 		t.Fatal("expected stream to be done")
 	}
 }
 
+func TestIntegration_StorageReadArrow(t *testing.T) {
+	if client == nil {
+		t.Skip("Integration tests skipped")
+	}
+	ctx := context.Background()
+	table := "`bigquery-public-data.usa_names.usa_1910_current`"
+	sql := fmt.Sprintf(`SELECT name, number, state FROM %s where state = "CA"`, table)
+
+	q := storageOptimizedClient.Query(sql)
+	job, err := q.Run(ctx) // force usage of Storage API by skipping fast paths
+	if err != nil {
+		t.Fatal(err)
+	}
+	it, err := job.Read(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	checkedAllocator := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	it.arrowDecoder.allocator = checkedAllocator
+	defer checkedAllocator.AssertSize(t, 0)
+
+	arrowIt, err := it.ArrowIterator()
+	if err != nil {
+		t.Fatalf("expected iterator to be accelerated: %v", err)
+	}
+	arrowItReader := NewArrowIteratorReader(arrowIt)
+
+	records := []arrow.Record{}
+	r, err := ipc.NewReader(arrowItReader, ipc.WithAllocator(checkedAllocator))
+	if err != nil {
+		t.Fatal(err)
+	}
+	numrec := 0
+	for r.Next() {
+		rec := r.Record()
+		rec.Retain()
+		defer rec.Release()
+		records = append(records, rec)
+		numrec += int(rec.NumRows())
+	}
+	arrowSchema := r.Schema()
+	r.Release()
+
+	arrowTable := array.NewTableFromRecords(arrowSchema, records)
+	defer arrowTable.Release()
+	if arrowTable.NumRows() != int64(it.TotalRows) {
+		t.Fatalf("should have a table with %d rows, but found %d", it.TotalRows, arrowTable.NumRows())
+	}
+	if arrowTable.NumCols() != 3 {
+		t.Fatalf("should have a table with 3 columns, but found %d", arrowTable.NumCols())
+	}
+
+	sumSQL := fmt.Sprintf(`SELECT sum(number) as total FROM %s where state = "CA"`, table)
+	sumQuery := client.Query(sumSQL)
+	sumIt, err := sumQuery.Read(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	sumValues := []Value{}
+	err = sumIt.Next(&sumValues)
+	if err != nil {
+		t.Fatal(err)
+	}
+	totalFromSQL := sumValues[0].(int64)
+
+	tr := array.NewTableReader(arrowTable, arrowTable.NumRows())
+	defer tr.Release()
+	var totalFromArrow int64
+	for tr.Next() {
+		rec := tr.Record()
+		vec := rec.Column(1).(*array.Int64)
+		totalFromArrow += math.Int64.Sum(vec)
+	}
+	if totalFromArrow != totalFromSQL {
+		t.Fatalf("expected total to be %d, but with arrow we got %d", totalFromSQL, totalFromArrow)
+	}
+}
+
 func countIteratorRows(it *RowIterator) (total uint64, err error) {
 	for {
 		var dst []Value
diff --git a/bigquery/storage_iterator.go b/bigquery/storage_iterator.go
index ae5ae2a66d17..96e12f869a77 100644
--- a/bigquery/storage_iterator.go
+++ b/bigquery/storage_iterator.go
@@ -32,20 +32,21 @@ import (
 	"google.golang.org/grpc/status"
 )
 
-// arrowIterator is a raw interface for getting data from Storage Read API
-type arrowIterator struct {
-	done uint32 // atomic flag
-	errs chan error
-	ctx  context.Context
+// storageArrowIterator is the ArrowIterator implementation that streams raw data from the Storage Read API.
+type storageArrowIterator struct {
+	done        uint32 // atomic flag
+	initialized bool
+	errs        chan error
+	ctx         context.Context
 
-	schema  Schema
-	decoder *arrowDecoder
-	records chan arrowRecordBatch
+	schema    Schema
+	rawSchema []byte
+	records   chan *ArrowRecordBatch
 
 	session *readSession
 }
 
-type arrowRecordBatch []byte
+var _ ArrowIterator = &storageArrowIterator{}
 
 func newStorageRowIteratorFromTable(ctx context.Context, table *Table, ordered bool) (*RowIterator, error) {
 	md, err := table.Metadata(ctx)
@@ -56,11 +57,19 @@ func newStorageRowIteratorFromTable(ctx context.Context, table *Table, ordered b
 	if err != nil {
 		return nil, err
 	}
-	it, err := newStorageRowIterator(rs)
+	it, err := newStorageRowIterator(rs, md.Schema)
 	if err != nil {
 		return nil, err
 	}
-	it.arrowIterator.schema = md.Schema
+	if rs.bqSession == nil {
+		return nil, errors.New("read session not initialized")
+	}
+	arrowSerializedSchema := rs.bqSession.GetArrowSchema().GetSerializedSchema()
+	dec, err := newArrowDecoder(arrowSerializedSchema, md.Schema)
+	if err != nil {
+		return nil, err
+	}
+	it.arrowDecoder = dec
 	it.Schema = md.Schema
 	return it, nil
 }
@@ -112,11 +121,12 @@ func resolveLastChildSelectJob(ctx context.Context, job *Job) (*Job, error) {
 	return childJobs[0], nil
 }
 
-func newRawStorageRowIterator(rs *readSession) (*arrowIterator, error) {
-	arrowIt := &arrowIterator{
+func newRawStorageRowIterator(rs *readSession, schema Schema) (*storageArrowIterator, error) {
+	arrowIt := &storageArrowIterator{
 		ctx:     rs.ctx,
 		session: rs,
-		records: make(chan arrowRecordBatch, rs.settings.maxWorkerCount+1),
+		schema:  schema,
+		records: make(chan *ArrowRecordBatch, rs.settings.maxWorkerCount+1),
 		errs:    make(chan error, rs.settings.maxWorkerCount+1),
 	}
 	if rs.bqSession == nil {
@@ -125,11 +135,12 @@ func newRawStorageRowIterator(rs *readSession) (*arrowIterator, error) {
 			return nil, err
 		}
 	}
+	arrowIt.rawSchema = rs.bqSession.GetArrowSchema().GetSerializedSchema()
 	return arrowIt, nil
 }
 
-func newStorageRowIterator(rs *readSession) (*RowIterator, error) {
-	arrowIt, err := newRawStorageRowIterator(rs)
+func newStorageRowIterator(rs *readSession, schema Schema) (*RowIterator, error) {
+	arrowIt, err := newRawStorageRowIterator(rs, schema)
 	if err != nil {
 		return nil, err
 	}
@@ -153,8 +164,7 @@ func nextFuncForStorageIterator(it *RowIterator) func() error {
 		if len(it.rows) > 0 {
 			return nil
 		}
-		arrowIt := it.arrowIterator
-		record, err := arrowIt.next()
+		record, err := it.arrowIterator.Next()
 		if err == iterator.Done {
 			if len(it.rows) == 0 {
 				return iterator.Done
@@ -165,9 +175,9 @@ func nextFuncForStorageIterator(it *RowIterator) func() error {
 			return err
 		}
 		if it.Schema == nil {
-			it.Schema = it.arrowIterator.schema
+			it.Schema = it.arrowIterator.Schema()
 		}
-		rows, err := arrowIt.decoder.decodeArrowRecords(record)
+		rows, err := it.arrowDecoder.decodeArrowRecords(record)
 		if err != nil {
 			return err
 		}
@@ -176,8 +186,8 @@ func nextFuncForStorageIterator(it *RowIterator) func() error {
 	}
 }
 
-func (it *arrowIterator) init() error {
-	if it.decoder != nil { // Already initialized
+func (it *storageArrowIterator) init() error {
+	if it.initialized {
 		return nil
 	}
 
@@ -191,20 +201,6 @@ func (it *arrowIterator) init() error {
 		return iterator.Done
 	}
 
-	if it.schema == nil {
-		meta, err := it.session.table.Metadata(it.ctx)
-		if err != nil {
-			return err
-		}
-		it.schema = meta.Schema
-	}
-
-	decoder, err := newArrowDecoderFromSession(it.session, it.schema)
-	if err != nil {
-		return err
-	}
-	it.decoder = decoder
-
 	wg := sync.WaitGroup{}
 	wg.Add(len(streams))
 	sem := semaphore.NewWeighted(int64(it.session.settings.maxWorkerCount))
@@ -229,18 +225,19 @@ func (it *arrowIterator) init() error {
 			}(readStream.Name)
 		}
 	}()
+	it.initialized = true
 	return nil
 }
 
-func (it *arrowIterator) markDone() {
+func (it *storageArrowIterator) markDone() {
 	atomic.StoreUint32(&it.done, 1)
 }
 
-func (it *arrowIterator) isDone() bool {
+func (it *storageArrowIterator) isDone() bool {
 	return atomic.LoadUint32(&it.done) != 0
 }
 
-func (it *arrowIterator) processStream(readStream string) {
+func (it *storageArrowIterator) processStream(readStream string) {
 	bo := gax.Backoff{}
 	var offset int64
 	for {
@@ -300,7 +297,7 @@ func retryReadRows(bo gax.Backoff, err error) (time.Duration, bool) {
 	return bo.Pause(), false
 }
 
-func (it *arrowIterator) consumeRowStream(readStream string, rowStream storagepb.BigQueryRead_ReadRowsClient, offset int64) (int64, error) {
+func (it *storageArrowIterator) consumeRowStream(readStream string, rowStream storagepb.BigQueryRead_ReadRowsClient, offset int64) (int64, error) {
 	for {
 		r, err := rowStream.Recv()
 		if err != nil {
@@ -311,8 +308,12 @@ func (it *arrowIterator) consumeRowStream(readStream string, rowStream storagepb
 		}
 		if r.RowCount > 0 {
 			offset += r.RowCount
-			arrowRecordBatch := r.GetArrowRecordBatch()
-			it.records <- arrowRecordBatch.SerializedRecordBatch
+			recordBatch := r.GetArrowRecordBatch()
+			it.records <- &ArrowRecordBatch{
+				PartitionID: readStream,
+				Schema:      it.rawSchema,
+				Data:        recordBatch.SerializedRecordBatch,
+			}
 		}
 	}
 }
@@ -320,7 +321,7 @@ func (it *arrowIterator) consumeRowStream(readStream string, rowStream storagepb
 
-// next return the next batch of rows as an arrow.Record.
-// Accessing Arrow Records directly has the drawnback of having to deal
-// with memory management.
-func (it *arrowIterator) next() (arrowRecordBatch, error) {
+// Next returns the next batch of rows as an ArrowRecordBatch.
+// Accessing Arrow records directly has the drawback of having to deal
+// with memory management.
+func (it *storageArrowIterator) Next() (*ArrowRecordBatch, error) {
 	if err := it.init(); err != nil {
 		return nil, err
 	}
@@ -343,8 +344,28 @@
 	}
 }
 
+func (it *storageArrowIterator) SerializedArrowSchema() []byte {
+	return it.rawSchema
+}
+
+func (it *storageArrowIterator) Schema() Schema {
+	return it.schema
+}
+
 // IsAccelerated checks whether the current RowIterator is
 // being accelerated by the Storage API.
 func (it *RowIterator) IsAccelerated() bool {
 	return it.arrowIterator != nil
 }
+
+// ArrowIterator gives access to the raw Arrow Record Batch stream to be consumed directly.
+// Experimental: this interface is experimental and may be modified or removed in future versions,
+// regardless of any other documented package stability guarantees.
+// Do not mix RowIterator.Next and ArrowIterator.Next calls on the same iterator.
+func (it *RowIterator) ArrowIterator() (ArrowIterator, error) {
+	if !it.IsAccelerated() {
+		// TODO: can we convert a plain RowIterator, based on the JSON API, to an Arrow stream?
+		return nil, errors.New("bigquery: storage read API must be enabled")
+	}
+	return it.arrowIterator, nil
+}
diff --git a/bigquery/storage_iterator_test.go b/bigquery/storage_iterator_test.go
index 75e0d87973a1..10471184f544 100644
--- a/bigquery/storage_iterator_test.go
+++ b/bigquery/storage_iterator_test.go
@@ -110,7 +110,7 @@ func TestStorageIteratorRetry(t *testing.T) {
 			settings:     defaultReadClientSettings(),
 			readRowsFunc: readRowsFunc,
 			bqSession:    &storagepb.ReadSession{},
-		})
+		}, Schema{})
 		if err != nil {
 			t.Fatalf("case %s: newRawStorageRowIterator: %v", tc.desc, err)
 		}
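
Note for reviewers: a minimal end-to-end sketch of the new surface introduced above (RowIterator.ArrowIterator plus NewArrowIteratorReader). This is not part of the diff; the project ID and query below are placeholders, it assumes Arrow v12 and that EnableStorageReadClient succeeds, and error handling is condensed for brevity.

package main

import (
	"context"
	"fmt"
	"log"

	"cloud.google.com/go/bigquery"
	"github.com/apache/arrow/go/v12/arrow/ipc"
)

func main() {
	ctx := context.Background()

	// "my-project" is a placeholder project ID.
	client, err := bigquery.NewClient(ctx, "my-project")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// The Storage Read API must be enabled; otherwise ArrowIterator()
	// below returns an error because the iterator is not accelerated.
	if err := client.EnableStorageReadClient(ctx); err != nil {
		log.Fatal(err)
	}

	q := client.Query("SELECT name, number FROM `bigquery-public-data.usa_names.usa_1910_current` LIMIT 1000")
	it, err := q.Read(ctx)
	if err != nil {
		log.Fatal(err)
	}

	arrowIt, err := it.ArrowIterator()
	if err != nil {
		log.Fatal(err)
	}

	// NewArrowIteratorReader emits the serialized schema followed by each
	// serialized record batch, which is exactly the Arrow IPC stream
	// framing, so the reader can be handed directly to ipc.NewReader.
	r, err := ipc.NewReader(bigquery.NewArrowIteratorReader(arrowIt))
	if err != nil {
		log.Fatal(err)
	}
	defer r.Release()

	for r.Next() {
		rec := r.Record()
		fmt.Printf("batch: %d rows, %d cols\n", rec.NumRows(), rec.NumCols())
	}
	if err := r.Err(); err != nil {
		log.Fatal(err)
	}
}

Callers that consume the ArrowIterator directly, rather than through the reader, also get the PartitionID on each ArrowRecordBatch, which identifies the originating ReadStream.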
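
Because NewArrowIteratorReader returns a plain io.Reader, the stream can also be persisted or piped without decoding anything in-process. A sketch under the same assumptions, with a hypothetical output path:

package example

import (
	"io"
	"os"

	"cloud.google.com/go/bigquery"
)

// dumpStream copies the raw Arrow IPC stream from an accelerated
// RowIterator straight to disk; no record batch is ever decoded here.
func dumpStream(it *bigquery.RowIterator) error {
	arrowIt, err := it.ArrowIterator()
	if err != nil {
		return err // iterator is not backed by the Storage Read API
	}
	// "result.arrows" is a placeholder path; the file contains a complete
	// IPC stream (schema message followed by record batches) that any
	// Arrow-aware tool can read back.
	f, err := os.Create("result.arrows")
	if err != nil {
		return err
	}
	defer f.Close()
	_, err = io.Copy(f, bigquery.NewArrowIteratorReader(arrowIt))
	return err
}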