From 19a175dc27a17683a6ca075249e0dd2a1e3d9a62 Mon Sep 17 00:00:00 2001 From: "e.sevastyanov" Date: Fri, 16 Feb 2024 13:16:07 +0400 Subject: [PATCH 1/3] Parallel fetch of batches in Snowflake connector --- runtime/drivers/snowflake/sql_store.go | 94 ++++++++++++++++++++++---- 1 file changed, 81 insertions(+), 13 deletions(-) diff --git a/runtime/drivers/snowflake/sql_store.go b/runtime/drivers/snowflake/sql_store.go index 7d3bbd4974b..687e39151dc 100644 --- a/runtime/drivers/snowflake/sql_store.go +++ b/runtime/drivers/snowflake/sql_store.go @@ -7,6 +7,8 @@ import ( "fmt" "io" "os" + "strconv" + "sync" "time" "github.com/apache/arrow/go/v14/arrow" @@ -20,6 +22,7 @@ import ( "github.com/rilldata/rill/runtime/pkg/observability" sf "github.com/snowflakedb/gosnowflake" "go.uber.org/zap" + "golang.org/x/sync/semaphore" ) // recommended size is 512MB - 1GB, entire data is buffered in memory before its written to disk @@ -49,6 +52,20 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any, opt return nil, fmt.Errorf("the property 'dsn' is required for Snowflake. Provide 'dsn' in the YAML properties or pass '--var connector.snowflake.dsn=...' to 'rill start'") } + // Parse dsn and extract parallelFetchLimit + parallelFetchLimit := 10 + dsnConfig, err := sf.ParseDSN(dsn) + if err != nil { + return nil, err + } + fetchLimitPtr, exists := dsnConfig.Params["parallelFetchLimit"] + if exists && fetchLimitPtr != nil { + v, err := strconv.Atoi(*fetchLimitPtr) + if err == nil { + parallelFetchLimit = v + } + } + db, err := sql.Open("snowflake", dsn) if err != nil { return nil, err @@ -87,14 +104,15 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any, opt p.Target(1, drivers.ProgressUnitFile) return &fileIterator{ - ctx: ctx, - db: db, - conn: conn, - rows: rows, - batches: batches, - progress: p, - limitInBytes: opt.TotalLimitInBytes, - logger: c.logger, + ctx: ctx, + db: db, + conn: conn, + rows: rows, + batches: batches, + progress: p, + limitInBytes: opt.TotalLimitInBytes, + parallelFetchLimit: parallelFetchLimit, + logger: c.logger, }, nil } @@ -111,6 +129,8 @@ type fileIterator struct { totalRecords int64 tempFilePath string downloaded bool + // Max number of batches to fetch in parallel + parallelFetchLimit int } // Close implements drivers.FileIterator. 
@@ -184,13 +204,47 @@ func (f *fileIterator) Next() ([]string, error) { // write arrow records to parquet file // the following iteration might be memory intensive // since batches are organized as a slice and every batch caches its content + batchesLeft := len(f.batches) + f.logger.Debug("starting to fetch and process arrow batches", + zap.Int("batches", batchesLeft), zap.Int("parallel_fetch_limit", f.parallelFetchLimit)) + + // Fetch batches async as it takes most of the time + var wg sync.WaitGroup + fetchResultChan := make(chan fetchResult, len(f.batches)) + sem := semaphore.NewWeighted(int64(f.parallelFetchLimit)) for _, batch := range f.batches { - records, err := batch.Fetch() - if err != nil { - return nil, err + wg.Add(1) + go func(b *sf.ArrowBatch) { + defer wg.Done() + defer sem.Release(1) + err := sem.Acquire(f.ctx, 1) + if err != nil { + fetchResultChan <- fetchResult{Records: nil, Batch: nil, Err: err} + return + } + fetchStart := time.Now() + records, err := b.Fetch() + fetchResultChan <- fetchResult{Records: records, Batch: b, Err: err} + f.logger.Debug( + "fetched an arrow batch", + zap.Duration("duration", time.Since(fetchStart)), + zap.Int("row_count", b.GetRowCount()), + ) + }(batch) + } + + go func() { + wg.Wait() + close(fetchResultChan) + }() + + for result := range fetchResultChan { + if result.Err != nil { + return nil, result.Err } - f.totalRecords += int64(batch.GetRowCount()) - for _, rec := range *records { + batch := result.Batch + writeStart := time.Now() + for _, rec := range *result.Records { if writer.RowGroupTotalBytesWritten() >= rowGroupBufferSize { writer.NewBufferedRowGroup() } @@ -204,6 +258,14 @@ func (f *fileIterator) Next() ([]string, error) { } } } + batchesLeft-- + f.logger.Debug( + "wrote an arrow batch to a parquet file", + zap.Float64("progress", float64(len(f.batches)-batchesLeft)/float64(len(f.batches))*100), + zap.Int("row_count", batch.GetRowCount()), + zap.Duration("write_duration", time.Since(writeStart)), + ) + f.totalRecords += int64(result.Batch.GetRowCount()) } writer.Close() @@ -237,6 +299,12 @@ func (f *fileIterator) Format() string { var _ drivers.FileIterator = &fileIterator{} +type fetchResult struct { + Records *[]arrow.Record + Batch *sf.ArrowBatch + Err error +} + type sourceProperties struct { SQL string `mapstructure:"sql"` DSN string `mapstructure:"dsn"` From 03a94e8260aeb31901cefb61fed3089b27c1bfe6 Mon Sep 17 00:00:00 2001 From: "e.sevastyanov" Date: Fri, 16 Feb 2024 22:46:59 +0400 Subject: [PATCH 2/3] Errgroups for fetching and writing --- runtime/drivers/snowflake/sql_store.go | 138 +++++++++++++------------ 1 file changed, 73 insertions(+), 65 deletions(-) diff --git a/runtime/drivers/snowflake/sql_store.go b/runtime/drivers/snowflake/sql_store.go index 687e39151dc..e662b6faa6c 100644 --- a/runtime/drivers/snowflake/sql_store.go +++ b/runtime/drivers/snowflake/sql_store.go @@ -8,7 +8,6 @@ import ( "io" "os" "strconv" - "sync" "time" "github.com/apache/arrow/go/v14/arrow" @@ -22,7 +21,7 @@ import ( "github.com/rilldata/rill/runtime/pkg/observability" sf "github.com/snowflakedb/gosnowflake" "go.uber.org/zap" - "golang.org/x/sync/semaphore" + "golang.org/x/sync/errgroup" ) // recommended size is 512MB - 1GB, entire data is buffered in memory before its written to disk @@ -52,17 +51,11 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any, opt return nil, fmt.Errorf("the property 'dsn' is required for Snowflake. Provide 'dsn' in the YAML properties or pass '--var connector.snowflake.dsn=...' 
to 'rill start'") } - // Parse dsn and extract parallelFetchLimit - parallelFetchLimit := 10 - dsnConfig, err := sf.ParseDSN(dsn) - if err != nil { - return nil, err - } - fetchLimitPtr, exists := dsnConfig.Params["parallelFetchLimit"] - if exists && fetchLimitPtr != nil { - v, err := strconv.Atoi(*fetchLimitPtr) - if err == nil { - parallelFetchLimit = v + parallelFetchLimit := 15 + if limit, ok := c.config["parallel_fetch_limit"].(string); ok { + parallelFetchLimit, err = strconv.Atoi(limit) + if err != nil { + return nil, err } } @@ -204,68 +197,84 @@ func (f *fileIterator) Next() ([]string, error) { // write arrow records to parquet file // the following iteration might be memory intensive // since batches are organized as a slice and every batch caches its content - batchesLeft := len(f.batches) f.logger.Debug("starting to fetch and process arrow batches", - zap.Int("batches", batchesLeft), zap.Int("parallel_fetch_limit", f.parallelFetchLimit)) + zap.Int("batches", len(f.batches)), zap.Int("parallel_fetch_limit", f.parallelFetchLimit)) + + // Fetch batches async + fetchGrp, ctx := errgroup.WithContext(f.ctx) + fetchGrp.SetLimit(f.parallelFetchLimit) + fetchResultChan := make(chan fetchResult) + + // Write batches into a file async + writeGrp, _ := errgroup.WithContext(f.ctx) + writeGrp.Go(func() error { + batchesLeft := len(f.batches) + for { + select { + case result, ok := <-fetchResultChan: + if !ok { + return nil + } + batch := result.batch + writeStart := time.Now() + for _, rec := range *result.records { + if writer.RowGroupTotalBytesWritten() >= rowGroupBufferSize { + writer.NewBufferedRowGroup() + } + if err := writer.WriteBuffered(rec); err != nil { + return err + } + fileInfo, err := os.Stat(fw.Name()) + if err == nil { // ignore error + if fileInfo.Size() > f.limitInBytes { + return drivers.ErrStorageLimitExceeded + } + } + } + batchesLeft-- + f.logger.Debug( + "wrote an arrow batch to a parquet file", + zap.Float64("progress", float64(len(f.batches)-batchesLeft)/float64(len(f.batches))*100), + zap.Int("row_count", batch.GetRowCount()), + zap.Duration("write_duration", time.Since(writeStart)), + ) + f.totalRecords += int64(result.batch.GetRowCount()) + case <-ctx.Done(): + if ctx.Err() != nil { + return nil + } + } + } + }) - // Fetch batches async as it takes most of the time - var wg sync.WaitGroup - fetchResultChan := make(chan fetchResult, len(f.batches)) - sem := semaphore.NewWeighted(int64(f.parallelFetchLimit)) for _, batch := range f.batches { - wg.Add(1) - go func(b *sf.ArrowBatch) { - defer wg.Done() - defer sem.Release(1) - err := sem.Acquire(f.ctx, 1) - if err != nil { - fetchResultChan <- fetchResult{Records: nil, Batch: nil, Err: err} - return - } + b := batch + fetchGrp.Go(func() error { fetchStart := time.Now() records, err := b.Fetch() - fetchResultChan <- fetchResult{Records: records, Batch: b, Err: err} + if err != nil { + return err + } + fetchResultChan <- fetchResult{records: records, batch: b} f.logger.Debug( "fetched an arrow batch", zap.Duration("duration", time.Since(fetchStart)), zap.Int("row_count", b.GetRowCount()), ) - }(batch) + return nil + }) } - go func() { - wg.Wait() - close(fetchResultChan) - }() + err = fetchGrp.Wait() + ctx.Done() + close(fetchResultChan) - for result := range fetchResultChan { - if result.Err != nil { - return nil, result.Err - } - batch := result.Batch - writeStart := time.Now() - for _, rec := range *result.Records { - if writer.RowGroupTotalBytesWritten() >= rowGroupBufferSize { - writer.NewBufferedRowGroup() - } - 
if err := writer.WriteBuffered(rec); err != nil { - return nil, err - } - fileInfo, err := os.Stat(fw.Name()) - if err == nil { // ignore error - if fileInfo.Size() > f.limitInBytes { - return nil, drivers.ErrStorageLimitExceeded - } - } - } - batchesLeft-- - f.logger.Debug( - "wrote an arrow batch to a parquet file", - zap.Float64("progress", float64(len(f.batches)-batchesLeft)/float64(len(f.batches))*100), - zap.Int("row_count", batch.GetRowCount()), - zap.Duration("write_duration", time.Since(writeStart)), - ) - f.totalRecords += int64(result.Batch.GetRowCount()) + if err != nil { + return nil, err + } + + if err := writeGrp.Wait(); err != nil { + return nil, err } writer.Close() @@ -300,9 +309,8 @@ func (f *fileIterator) Format() string { var _ drivers.FileIterator = &fileIterator{} type fetchResult struct { - Records *[]arrow.Record - Batch *sf.ArrowBatch - Err error + records *[]arrow.Record + batch *sf.ArrowBatch } type sourceProperties struct { From f61b300b052a4af9a0902d6fefe00780b9a7151b Mon Sep 17 00:00:00 2001 From: "e.sevastyanov" Date: Mon, 19 Feb 2024 17:24:21 +0400 Subject: [PATCH 3/3] Fixed context cancellation case --- runtime/drivers/snowflake/sql_store.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/runtime/drivers/snowflake/sql_store.go b/runtime/drivers/snowflake/sql_store.go index e662b6faa6c..a1c59eac455 100644 --- a/runtime/drivers/snowflake/sql_store.go +++ b/runtime/drivers/snowflake/sql_store.go @@ -240,9 +240,7 @@ func (f *fileIterator) Next() ([]string, error) { ) f.totalRecords += int64(result.batch.GetRowCount()) case <-ctx.Done(): - if ctx.Err() != nil { - return nil - } + return ctx.Err() } } }) @@ -266,7 +264,6 @@ func (f *fileIterator) Next() ([]string, error) { } err = fetchGrp.Wait() - ctx.Done() close(fetchResultChan) if err != nil {
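
Note: the sketch below is an illustrative, self-contained reduction of the concurrency pattern the series ends up with in fileIterator.Next: a fetch errgroup bounded by parallelFetchLimit produces results on a channel, and a single writer goroutine consumes them sequentially. The batch and result types, the simulated fetch delay, and the row counts are stand-ins invented for this example only; the real connector works with *sf.ArrowBatch, arrow.Record, and a buffered Parquet writer, and its error handling differs in detail.

package main

import (
	"context"
	"fmt"
	"math/rand"
	"time"

	"golang.org/x/sync/errgroup"
)

// batch stands in for an expensive-to-fetch unit of work (an Arrow batch in the connector).
type batch struct{ id int }

// result pairs a fetched payload with its originating batch, like fetchResult in the patch.
type result struct {
	b    *batch
	rows int
}

// fetch simulates a slow, network-bound fetch that respects cancellation.
func (b *batch) fetch(ctx context.Context) (int, error) {
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case <-time.After(time.Duration(50+rand.Intn(100)) * time.Millisecond):
		return 1000 + b.id, nil
	}
}

// process fetches batches with bounded parallelism and writes them from a single goroutine.
func process(ctx context.Context, batches []*batch, parallelFetchLimit int) (int64, error) {
	// Fetchers: at most parallelFetchLimit goroutines run at once.
	fetchGrp, fctx := errgroup.WithContext(ctx)
	fetchGrp.SetLimit(parallelFetchLimit)
	results := make(chan result)

	// Single writer: consumes results sequentially, like the Parquet writer in the connector.
	writeGrp, wctx := errgroup.WithContext(ctx)
	var totalRows int64
	writeGrp.Go(func() error {
		for {
			select {
			case r, ok := <-results:
				if !ok {
					return nil // channel closed: every fetched batch has been written
				}
				totalRows += int64(r.rows)
			case <-wctx.Done():
				return wctx.Err()
			}
		}
	})

	for _, b := range batches {
		b := b
		fetchGrp.Go(func() error {
			rows, err := b.fetch(fctx)
			if err != nil {
				return err
			}
			select {
			case results <- result{b: b, rows: rows}:
				return nil
			case <-wctx.Done():
				// The writer stopped; don't block forever on the send.
				return wctx.Err()
			}
		})
	}

	fetchErr := fetchGrp.Wait()
	close(results) // no more results: lets the writer goroutine finish
	writeErr := writeGrp.Wait()
	if fetchErr != nil {
		return 0, fetchErr
	}
	if writeErr != nil {
		return 0, writeErr
	}
	return totalRows, nil
}

func main() {
	batches := make([]*batch, 20)
	for i := range batches {
		batches[i] = &batch{id: i}
	}
	total, err := process(context.Background(), batches, 5)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println("total rows:", total)
}

Keeping a single consumer mirrors the design the patches settle on: the Parquet writer is only ever written to from one goroutine, while the network-bound fetch step runs concurrently but is capped so that at most parallelFetchLimit fetched batches are buffered in memory at a time.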