diff --git a/runtime/drivers/snowflake/sql_store.go b/runtime/drivers/snowflake/sql_store.go
index 7d3bbd4974b..a1c59eac455 100644
--- a/runtime/drivers/snowflake/sql_store.go
+++ b/runtime/drivers/snowflake/sql_store.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"os"
+	"strconv"
 	"time"
 
 	"github.com/apache/arrow/go/v14/arrow"
@@ -20,6 +21,7 @@ import (
 	"github.com/rilldata/rill/runtime/pkg/observability"
 	sf "github.com/snowflakedb/gosnowflake"
 	"go.uber.org/zap"
+	"golang.org/x/sync/errgroup"
 )
 
 // recommended size is 512MB - 1GB, entire data is buffered in memory before its written to disk
@@ -49,6 +51,14 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any, opt
 		return nil, fmt.Errorf("the property 'dsn' is required for Snowflake. Provide 'dsn' in the YAML properties or pass '--var connector.snowflake.dsn=...' to 'rill start'")
 	}
 
+	parallelFetchLimit := 15
+	if limit, ok := c.config["parallel_fetch_limit"].(string); ok {
+		parallelFetchLimit, err = strconv.Atoi(limit)
+		if err != nil {
+			return nil, err
+		}
+	}
+
 	db, err := sql.Open("snowflake", dsn)
 	if err != nil {
 		return nil, err
@@ -87,14 +97,15 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any, opt
 	p.Target(1, drivers.ProgressUnitFile)
 
 	return &fileIterator{
-		ctx:          ctx,
-		db:           db,
-		conn:         conn,
-		rows:         rows,
-		batches:      batches,
-		progress:     p,
-		limitInBytes: opt.TotalLimitInBytes,
-		logger:       c.logger,
+		ctx:                ctx,
+		db:                 db,
+		conn:               conn,
+		rows:               rows,
+		batches:            batches,
+		progress:           p,
+		limitInBytes:       opt.TotalLimitInBytes,
+		parallelFetchLimit: parallelFetchLimit,
+		logger:             c.logger,
 	}, nil
 }
 
@@ -111,6 +122,8 @@ type fileIterator struct {
 	totalRecords int64
 	tempFilePath string
 	downloaded   bool
+	// Max number of batches to fetch in parallel
+	parallelFetchLimit int
 }
 
 // Close implements drivers.FileIterator.
@@ -184,26 +197,81 @@ func (f *fileIterator) Next() ([]string, error) {
 	// write arrow records to parquet file
 	// the following iteration might be memory intensive
 	// since batches are organized as a slice and every batch caches its content
-	for _, batch := range f.batches {
-		records, err := batch.Fetch()
-		if err != nil {
-			return nil, err
-		}
-		f.totalRecords += int64(batch.GetRowCount())
-		for _, rec := range *records {
-			if writer.RowGroupTotalBytesWritten() >= rowGroupBufferSize {
-				writer.NewBufferedRowGroup()
-			}
-			if err := writer.WriteBuffered(rec); err != nil {
-				return nil, err
-			}
-			fileInfo, err := os.Stat(fw.Name())
-			if err == nil { // ignore error
-				if fileInfo.Size() > f.limitInBytes {
-					return nil, drivers.ErrStorageLimitExceeded
+	f.logger.Debug("starting to fetch and process arrow batches",
+		zap.Int("batches", len(f.batches)), zap.Int("parallel_fetch_limit", f.parallelFetchLimit))
+
+	// Fetch batches async
+	fetchGrp, ctx := errgroup.WithContext(f.ctx)
+	fetchGrp.SetLimit(f.parallelFetchLimit)
+	fetchResultChan := make(chan fetchResult)
+
+	// Write batches into a file async
+	writeGrp, _ := errgroup.WithContext(f.ctx)
+	writeGrp.Go(func() error {
+		batchesLeft := len(f.batches)
+		for {
+			select {
+			case result, ok := <-fetchResultChan:
+				if !ok {
+					return nil
+				}
+				batch := result.batch
+				writeStart := time.Now()
+				for _, rec := range *result.records {
+					if writer.RowGroupTotalBytesWritten() >= rowGroupBufferSize {
+						writer.NewBufferedRowGroup()
+					}
+					if err := writer.WriteBuffered(rec); err != nil {
+						return err
+					}
+					fileInfo, err := os.Stat(fw.Name())
+					if err == nil { // ignore error
+						if fileInfo.Size() > f.limitInBytes {
+							return drivers.ErrStorageLimitExceeded
+						}
+					}
 				}
+				batchesLeft--
+				f.logger.Debug(
+					"wrote an arrow batch to a parquet file",
+					zap.Float64("progress", float64(len(f.batches)-batchesLeft)/float64(len(f.batches))*100),
+					zap.Int("row_count", batch.GetRowCount()),
+					zap.Duration("write_duration", time.Since(writeStart)),
+				)
+				f.totalRecords += int64(result.batch.GetRowCount())
+			case <-ctx.Done():
+				return ctx.Err()
 			}
 		}
+	})
+
+	for _, batch := range f.batches {
+		b := batch
+		fetchGrp.Go(func() error {
+			fetchStart := time.Now()
+			records, err := b.Fetch()
+			if err != nil {
+				return err
+			}
+			fetchResultChan <- fetchResult{records: records, batch: b}
+			f.logger.Debug(
+				"fetched an arrow batch",
+				zap.Duration("duration", time.Since(fetchStart)),
+				zap.Int("row_count", b.GetRowCount()),
+			)
+			return nil
+		})
+	}
+
+	err = fetchGrp.Wait()
+	close(fetchResultChan)
+
+	if err != nil {
+		return nil, err
+	}
+
+	if err := writeGrp.Wait(); err != nil {
+		return nil, err
 	}
 
 	writer.Close()
@@ -237,6 +305,11 @@ func (f *fileIterator) Format() string {
 
 var _ drivers.FileIterator = &fileIterator{}
 
+type fetchResult struct {
+	records *[]arrow.Record
+	batch   *sf.ArrowBatch
+}
+
 type sourceProperties struct {
 	SQL string `mapstructure:"sql"`
 	DSN string `mapstructure:"dsn"`
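A note on the new `parallel_fetch_limit` knob introduced above: it arrives through the connector config as a string, is parsed with `strconv.Atoi`, and falls back to a default of 15. Going by the DSN error message in the same function, it can presumably be supplied the same way, e.g. `--var connector.snowflake.parallel_fetch_limit=8`. Below is a minimal sketch of that parse-with-default pattern under those assumptions; `intFromConfig` and `cfg` are hypothetical names, not part of the patch.

```go
package main

import (
	"fmt"
	"strconv"
)

// intFromConfig mirrors the patch's inline logic: return def when key is
// absent, and an error when the value is present but not a valid integer.
// The helper name is hypothetical; the patch inlines this in QueryAsFiles.
func intFromConfig(cfg map[string]any, key string, def int) (int, error) {
	s, ok := cfg[key].(string)
	if !ok {
		return def, nil
	}
	n, err := strconv.Atoi(s)
	if err != nil {
		return 0, fmt.Errorf("config %q: %w", key, err)
	}
	return n, nil
}

func main() {
	cfg := map[string]any{"parallel_fetch_limit": "8"}
	limit, err := intFromConfig(cfg, "parallel_fetch_limit", 15)
	fmt.Println(limit, err) // prints: 8 <nil>
}
```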
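The core of the change is a bounded fan-out/fan-in: an errgroup with `SetLimit(f.parallelFetchLimit)` fetches batches concurrently, a single writer goroutine drains `fetchResultChan` so the parquet writes stay sequential, and the channel is closed only after `fetchGrp.Wait()` returns so the writer's loop can terminate. Here is a self-contained sketch of that pattern with the same structure; `fetchBatch` and `batchCount` are hypothetical stand-ins for the `sf.ArrowBatch` machinery.

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// fetchBatch is a hypothetical stand-in for (*sf.ArrowBatch).Fetch.
func fetchBatch(id int) (string, error) {
	return fmt.Sprintf("batch-%d", id), nil
}

func main() {
	const batchCount = 10
	const parallelFetchLimit = 3 // the patch defaults to 15

	// Bounded producers: at most parallelFetchLimit fetches run at once.
	fetchGrp, ctx := errgroup.WithContext(context.Background())
	fetchGrp.SetLimit(parallelFetchLimit)
	results := make(chan string)

	// Single consumer: only one goroutine performs writes, mirroring the
	// patch's writeGrp, which keeps the parquet writer single-threaded.
	writeGrp := new(errgroup.Group)
	writeGrp.Go(func() error {
		for {
			select {
			case r, ok := <-results:
				if !ok {
					return nil // channel closed: every fetch has been written
				}
				fmt.Println("wrote", r)
			case <-ctx.Done():
				return ctx.Err() // a producer failed; stop writing
			}
		}
	})

	for i := 0; i < batchCount; i++ {
		i := i // capture the loop variable, as the patch does with b := batch
		fetchGrp.Go(func() error {
			r, err := fetchBatch(i)
			if err != nil {
				return err
			}
			results <- r
			return nil
		})
	}

	// Close the channel only after all producers finish, so the consumer
	// sees ok == false and returns instead of blocking forever.
	err := fetchGrp.Wait()
	close(results)
	if err == nil {
		err = writeGrp.Wait()
	}
	if err != nil {
		fmt.Println("error:", err)
	}
}
```

Keeping the writer single-threaded preserves the original invariant that only one goroutine touches the parquet writer, while the fetch limit caps memory: each fetched batch stays buffered in the channel handoff until the writer consumes it.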