diff --git a/chunk_downloader.go b/chunk_downloader.go index 7da5c62ba..f715992c0 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -707,6 +707,13 @@ type ArrowBatch struct { rowCount int scd *snowflakeChunkDownloader funcDownloadHelper func(context.Context, *snowflakeChunkDownloader, int) error + ctx context.Context +} + +// WithContext sets the context which will be used for this ArrowBatch. +func (rb *ArrowBatch) WithContext(ctx context.Context) *ArrowBatch { + rb.ctx = ctx + return rb } // Fetch returns an array of records representing a chunk in the query @@ -717,7 +724,13 @@ func (rb *ArrowBatch) Fetch() (*[]arrow.Record, error) { rb.rowCount = countArrowBatchRows(rb.rec) return rb.rec, nil } - if err := rb.funcDownloadHelper(context.Background(), rb.scd, rb.idx); err != nil { + var ctx context.Context + if rb.ctx != nil { + ctx = rb.ctx + } else { + ctx = context.Background() + } + if err := rb.funcDownloadHelper(ctx, rb.scd, rb.idx); err != nil { return nil, err } return rb.rec, nil