Parallel fetch of batches in Snowflake connector #4070

Merged · 3 commits · Feb 19, 2024

Changes from 1 commit
94 changes: 81 additions & 13 deletions runtime/drivers/snowflake/sql_store.go
@@ -7,6 +7,8 @@ import (
"fmt"
"io"
"os"
"strconv"
"sync"
"time"

"github.com/apache/arrow/go/v14/arrow"
@@ -20,6 +22,7 @@ import (
"github.com/rilldata/rill/runtime/pkg/observability"
sf "github.com/snowflakedb/gosnowflake"
"go.uber.org/zap"
"golang.org/x/sync/semaphore"
)

// recommended size is 512MB - 1GB, entire data is buffered in memory before it's written to disk
@@ -49,6 +52,20 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any, opt
return nil, fmt.Errorf("the property 'dsn' is required for Snowflake. Provide 'dsn' in the YAML properties or pass '--var connector.snowflake.dsn=...' to 'rill start'")
}

// Parse dsn and extract parallelFetchLimit
parallelFetchLimit := 10
dsnConfig, err := sf.ParseDSN(dsn)
if err != nil {
return nil, err
}
fetchLimitPtr, exists := dsnConfig.Params["parallelFetchLimit"]

Contributor:

Is this a standard Snowflake property? If not, I think it would be better to set it as a separate connector property – i.e. configured using --var connector.snowflake.parallel_fetch_limit=N.

Contributor Author:

No, it isn't a Snowflake property. Moved to an env var:

-env connector.snowflake.parallel_fetch_limit=20
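
For illustration, a minimal sketch of reading such a limit from a generic connector-properties map instead of the DSN; the resolveParallelFetchLimit helper and the "parallel_fetch_limit" key below are assumptions for the sketch, not the code merged in this PR.

package main

import (
	"fmt"
	"strconv"
)

// resolveParallelFetchLimit looks up an optional "parallel_fetch_limit" entry
// in a properties map and falls back to a default when the key is absent,
// malformed, or non-positive. Hypothetical helper, not part of the PR.
func resolveParallelFetchLimit(props map[string]any, def int) int {
	raw, ok := props["parallel_fetch_limit"]
	if !ok {
		return def
	}
	switch v := raw.(type) {
	case int:
		if v > 0 {
			return v
		}
	case string:
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			return n
		}
	}
	return def
}

func main() {
	props := map[string]any{"parallel_fetch_limit": "20"}
	fmt.Println(resolveParallelFetchLimit(props, 10)) // prints 20
}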

if exists && fetchLimitPtr != nil {
v, err := strconv.Atoi(*fetchLimitPtr)
if err == nil {
parallelFetchLimit = v
}
}

db, err := sql.Open("snowflake", dsn)
if err != nil {
return nil, err
@@ -87,14 +104,15 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any, opt
p.Target(1, drivers.ProgressUnitFile)

return &fileIterator{
- ctx:          ctx,
- db:           db,
- conn:         conn,
- rows:         rows,
- batches:      batches,
- progress:     p,
- limitInBytes: opt.TotalLimitInBytes,
- logger:       c.logger,
+ ctx:                ctx,
+ db:                 db,
+ conn:               conn,
+ rows:               rows,
+ batches:            batches,
+ progress:           p,
+ limitInBytes:       opt.TotalLimitInBytes,
+ parallelFetchLimit: parallelFetchLimit,
+ logger:             c.logger,
}, nil
}

@@ -111,6 +129,8 @@ type fileIterator struct {
totalRecords int64
tempFilePath string
downloaded bool
// Max number of batches to fetch in parallel
parallelFetchLimit int
}

// Close implements drivers.FileIterator.
@@ -184,13 +204,47 @@ func (f *fileIterator) Next() ([]string, error) {
// write arrow records to parquet file
// the following iteration might be memory intensive
// since batches are organized as a slice and every batch caches its content
batchesLeft := len(f.batches)
f.logger.Debug("starting to fetch and process arrow batches",
zap.Int("batches", batchesLeft), zap.Int("parallel_fetch_limit", f.parallelFetchLimit))

// Fetch batches async as it takes most of the time
var wg sync.WaitGroup
fetchResultChan := make(chan fetchResult, len(f.batches))

Contributor:

What is the max number of batches? Wondering if this might take a lot of memory. If the writing phase is fast, I think we should avoid a buffered channel here. See: https://github.com/uber-go/guide/blob/master/style.md#channel-size-is-one-or-none

Contributor Author:

There is no way to control the number of batches and number of records per batch in Snowflake. There might be 1K batches for 100M rows.
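
As a rough illustration of the trade-off discussed above, a self-contained back-of-envelope sketch; the row size, batch count, and limit are made-up numbers, not measurements from this connector.

package main

import "fmt"

// With ~100M rows split into ~1,000 batches, a channel buffered to
// len(f.batches) can end up holding every fetched batch at once when the
// writer lags, whereas an unbuffered channel keeps at most
// parallelFetchLimit fetched batches in flight (senders block until the
// writer catches up).
func main() {
	const (
		totalRows          = 100_000_000
		batchCount         = 1_000
		bytesPerRow        = 100 // illustrative
		parallelFetchLimit = 10
	)
	worstCaseBuffered := int64(totalRows) * bytesPerRow                                   // every batch resident
	worstCaseUnbuffered := int64(totalRows/batchCount) * bytesPerRow * parallelFetchLimit // only in-flight batches
	fmt.Printf("buffered to len(batches): ~%d MB\n", worstCaseBuffered/1_000_000)   // ~10000 MB
	fmt.Printf("unbuffered, limit of 10:  ~%d MB\n", worstCaseUnbuffered/1_000_000) // ~100 MB
}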

sem := semaphore.NewWeighted(int64(f.parallelFetchLimit))
for _, batch := range f.batches {
- records, err := batch.Fetch()
- if err != nil {
- return nil, err
+ wg.Add(1)
+ go func(b *sf.ArrowBatch) {

Contributor:

Same question – how many batches might there be? Starting 10s-100s of Goroutines is okay, but probably worth optimizing if looking at thousands.

Contributor:

I wonder if it might be simpler to use an errgroup with WithContext and SetLimit, and then do the writing from a separate goroutine. It would avoid the semaphore, avoid potentially 100s of concurrent goroutines, and make error propagation / cancellation easier. Something like this:

grp, ctx := errgroup.WithContext(ctx)
grp.SetLimit(f.parallelFetchLimit)

go func() {
  for {
    select {
    case res, ok := <-resultChan:
      // Write...
    case <-ctx.Done():
      // Either finished or errored
      return
    }
  }
}()

for _, batch := range f.batches {
  grp.Go(...)
}

err := grp.Wait()
// done

Contributor Author:

Converted using errgroup. The writing also requires error propagation, so used errgroup for the writing too.
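
For reference, a minimal self-contained sketch of that shape: an errgroup with SetLimit bounds concurrent fetches, while a single writer goroutine drains an unbuffered channel. fetchBatch and the batch/fetched types are stand-ins, and this is an assumed arrangement rather than the code merged in this PR (which, per the reply above, also wraps the writer in an errgroup).

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

type batch struct{ id int }
type fetched struct{ rows []int }

// fetchBatch stands in for the real per-batch fetch (e.g. sf.ArrowBatch.Fetch).
func fetchBatch(_ context.Context, b batch) (fetched, error) {
	return fetched{rows: []int{b.id}}, nil
}

func fetchAndWrite(ctx context.Context, batches []batch, limit int) error {
	grp, ctx := errgroup.WithContext(ctx)
	grp.SetLimit(limit) // at most `limit` fetches run concurrently

	results := make(chan fetched) // unbuffered: the writer applies backpressure

	// Single writer goroutine, so only one goroutine touches the file writer.
	writeErr := make(chan error, 1)
	go func() {
		for res := range results {
			_ = res // stand-in for writing res.rows to the Parquet file
		}
		writeErr <- nil
	}()

	for _, b := range batches {
		b := b // capture loop variable (needed before Go 1.22)
		grp.Go(func() error {
			res, err := fetchBatch(ctx, b)
			if err != nil {
				return err
			}
			select {
			case results <- res:
				return nil
			case <-ctx.Done(): // another fetch failed or the caller cancelled
				return ctx.Err()
			}
		})
	}

	err := grp.Wait() // first fetch error, if any
	close(results)    // no more results; let the writer drain and exit
	if werr := <-writeErr; err == nil {
		err = werr
	}
	return err
}

func main() {
	fmt.Println(fetchAndWrite(context.Background(), []batch{{1}, {2}, {3}}, 2))
}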

defer wg.Done()
defer sem.Release(1)
err := sem.Acquire(f.ctx, 1)
if err != nil {
fetchResultChan <- fetchResult{Records: nil, Batch: nil, Err: err}
return
}

Contributor:

I think releases should be after successful acquire (else might panic).

Contributor Author:

Removed the use of a semaphore.
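
A small sketch of the ordering being pointed at, kept for reference even though the semaphore was dropped in the next revision: the Release should only be deferred once Acquire has succeeded, because releasing a permit that was never held can make semaphore.Weighted panic with "released more than held".

package main

import (
	"context"

	"golang.org/x/sync/semaphore"
)

// doFetch defers the Release only after Acquire succeeds; a failed Acquire
// returns early with nothing to release.
func doFetch(ctx context.Context, sem *semaphore.Weighted) error {
	if err := sem.Acquire(ctx, 1); err != nil {
		return err // e.g. context cancelled; no permit was acquired
	}
	defer sem.Release(1)
	// ... fetch the batch here ...
	return nil
}

func main() {
	sem := semaphore.NewWeighted(2)
	_ = doFetch(context.Background(), sem)
}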

fetchStart := time.Now()
records, err := b.Fetch()
fetchResultChan <- fetchResult{Records: records, Batch: b, Err: err}
f.logger.Debug(
"fetched an arrow batch",
zap.Duration("duration", time.Since(fetchStart)),
zap.Int("row_count", b.GetRowCount()),
)
}(batch)
}

go func() {
wg.Wait()
close(fetchResultChan)
}()

for result := range fetchResultChan {
if result.Err != nil {
return nil, result.Err
}
- f.totalRecords += int64(batch.GetRowCount())
- for _, rec := range *records {
+ batch := result.Batch
+ writeStart := time.Now()
+ for _, rec := range *result.Records {
if writer.RowGroupTotalBytesWritten() >= rowGroupBufferSize {
writer.NewBufferedRowGroup()
}
@@ -204,6 +258,14 @@ func (f *fileIterator) Next() ([]string, error) {
}
}
}
batchesLeft--
f.logger.Debug(
"wrote an arrow batch to a parquet file",
zap.Float64("progress", float64(len(f.batches)-batchesLeft)/float64(len(f.batches))*100),
zap.Int("row_count", batch.GetRowCount()),
zap.Duration("write_duration", time.Since(writeStart)),
)
f.totalRecords += int64(result.Batch.GetRowCount())
}

writer.Close()
@@ -237,6 +299,12 @@ func (f *fileIterator) Format() string {

var _ drivers.FileIterator = &fileIterator{}

type fetchResult struct {
Records *[]arrow.Record
Batch *sf.ArrowBatch
Err error
}

type sourceProperties struct {
SQL string `mapstructure:"sql"`
DSN string `mapstructure:"dsn"`