Parallel fetch of batches in Snowflake connector #4070
Changes from 1 commit
@@ -7,6 +7,8 @@ import (
     "fmt"
     "io"
     "os"
+    "strconv"
+    "sync"
     "time"

     "github.com/apache/arrow/go/v14/arrow"
@@ -20,6 +22,7 @@ import (
     "github.com/rilldata/rill/runtime/pkg/observability"
     sf "github.com/snowflakedb/gosnowflake"
     "go.uber.org/zap"
+    "golang.org/x/sync/semaphore"
 )

 // recommended size is 512MB - 1GB, entire data is buffered in memory before its written to disk
@@ -49,6 +52,20 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any, opt
         return nil, fmt.Errorf("the property 'dsn' is required for Snowflake. Provide 'dsn' in the YAML properties or pass '--var connector.snowflake.dsn=...' to 'rill start'")
     }

+    // Parse dsn and extract parallelFetchLimit
+    parallelFetchLimit := 10
+    dsnConfig, err := sf.ParseDSN(dsn)
+    if err != nil {
+        return nil, err
+    }
+    fetchLimitPtr, exists := dsnConfig.Params["parallelFetchLimit"]
+    if exists && fetchLimitPtr != nil {
+        v, err := strconv.Atoi(*fetchLimitPtr)
+        if err == nil {
+            parallelFetchLimit = v
+        }
+    }
+
     db, err := sql.Open("snowflake", dsn)
     if err != nil {
         return nil, err
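For context on how this parameter reaches the connector: gosnowflake's ParseDSN surfaces unrecognized query parameters through Config.Params, which is what the new code reads. A minimal sketch, assuming placeholder account, credential, database, and warehouse values (none of these appear in the PR):

```go
package main

import (
	"fmt"

	sf "github.com/snowflakedb/gosnowflake"
)

func main() {
	// Hypothetical DSN; only the parallelFetchLimit parameter is relevant here.
	dsn := "user:password@my_account/MY_DB/PUBLIC?warehouse=MY_WH&parallelFetchLimit=20"

	cfg, err := sf.ParseDSN(dsn)
	if err != nil {
		panic(err)
	}
	if v, ok := cfg.Params["parallelFetchLimit"]; ok && v != nil {
		fmt.Println("parallelFetchLimit:", *v) // prints 20
	} else {
		fmt.Println("parallelFetchLimit not set; the connector would fall back to 10")
	}
}
```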
@@ -87,14 +104,15 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any, opt
     p.Target(1, drivers.ProgressUnitFile)

     return &fileIterator{
-        ctx:          ctx,
-        db:           db,
-        conn:         conn,
-        rows:         rows,
-        batches:      batches,
-        progress:     p,
-        limitInBytes: opt.TotalLimitInBytes,
-        logger:       c.logger,
+        ctx:                ctx,
+        db:                 db,
+        conn:               conn,
+        rows:               rows,
+        batches:            batches,
+        progress:           p,
+        limitInBytes:       opt.TotalLimitInBytes,
+        parallelFetchLimit: parallelFetchLimit,
+        logger:             c.logger,
     }, nil
 }

@@ -111,6 +129,8 @@ type fileIterator struct {
     totalRecords int64
     tempFilePath string
     downloaded   bool
+    // Max number of batches to fetch in parallel
+    parallelFetchLimit int
 }

 // Close implements drivers.FileIterator.
@@ -184,13 +204,47 @@ func (f *fileIterator) Next() ([]string, error) {
     // write arrow records to parquet file
     // the following iteration might be memory intensive
     // since batches are organized as a slice and every batch caches its content
+    batchesLeft := len(f.batches)
+    f.logger.Debug("starting to fetch and process arrow batches",
+        zap.Int("batches", batchesLeft), zap.Int("parallel_fetch_limit", f.parallelFetchLimit))
+
+    // Fetch batches async as it takes most of the time
+    var wg sync.WaitGroup
+    fetchResultChan := make(chan fetchResult, len(f.batches))
Reviewer: What is the max number of batches? Wondering if this might take a lot of memory. If the writing phase is fast, I think we should avoid a buffered channel here. See: https://github.com/uber-go/guide/blob/master/style.md#channel-size-is-one-or-none

Author: There is no way to control the number of batches or the number of records per batch in Snowflake. There might be 1K batches for 100M rows.
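To illustrate the guideline the reviewer links to: with an unbuffered results channel, each fetching goroutine blocks until the writer has consumed its batch, so slow writing throttles fetching instead of letting every fetched batch accumulate in the channel. A toy, self-contained sketch of that backpressure pattern (not the connector's code; all names here are invented):

```go
package main

import (
	"fmt"
	"sync"
)

// fetch stands in for the expensive per-batch download.
func fetch(i int) string { return fmt.Sprintf("batch-%d", i) }

func main() {
	const limit = 3 // analogous to parallelFetchLimit

	results := make(chan string)      // unbuffered: producers wait for the writer
	sem := make(chan struct{}, limit) // bounds concurrent fetches

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			sem <- struct{}{}        // acquire a fetch slot
			defer func() { <-sem }() // release it
			results <- fetch(i)      // blocks until the writer reads: backpressure
		}(i)
	}

	go func() {
		wg.Wait()
		close(results)
	}()

	for r := range results { // the "writer" side
		fmt.Println("writing", r)
	}
}
```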
+    sem := semaphore.NewWeighted(int64(f.parallelFetchLimit))
     for _, batch := range f.batches {
-        records, err := batch.Fetch()
-        if err != nil {
-            return nil, err
+        wg.Add(1)
+        go func(b *sf.ArrowBatch) {
Reviewer: Same question – how many batches might there be? Starting 10s-100s of goroutines is okay, but probably worth optimizing if looking at thousands.

Reviewer: I wonder if it might be simpler to use an errgroup:

    grp, ctx := errgroup.WithContext(ctx)
    grp.SetLimit(f.parallelFetchLimit)
    go func() {
        for {
            select {
            case res, ok := <-resultChan:
                // Write...
            case <-ctx.Done():
                // Either finished or errored
                return
            }
        }
    }()
    for _, batch := range f.batches {
        grp.Go(...)
    }
    err := grp.Wait()
    // done

Author: Converted using errgroup.
+            defer wg.Done()
+            defer sem.Release(1)
+            err := sem.Acquire(f.ctx, 1)
+            if err != nil {
+                fetchResultChan <- fetchResult{Records: nil, Batch: nil, Err: err}
+                return
+            }
Reviewer: I think the release should come after a successful acquire (otherwise it might panic).

Author: Removed the use of a semaphore.
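To make the ordering point concrete: in the diff above, Release is deferred before Acquire, so a failed Acquire (for example, on a canceled context) would still trigger Release for a weight that was never held, which panics. A small self-contained sketch of the corrected ordering (not the PR's code):

```go
package main

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/semaphore"
)

func main() {
	ctx := context.Background()
	sem := semaphore.NewWeighted(2) // at most 2 fetches in flight

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// Acquire first; defer Release only once the acquire has succeeded.
			// Deferring Release before Acquire would panic on a failed Acquire,
			// because the weight was never held.
			if err := sem.Acquire(ctx, 1); err != nil {
				fmt.Println("acquire failed:", err)
				return
			}
			defer sem.Release(1)
			fmt.Println("fetching batch", i)
		}(i)
	}
	wg.Wait()
}
```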
+            fetchStart := time.Now()
+            records, err := b.Fetch()
+            fetchResultChan <- fetchResult{Records: records, Batch: b, Err: err}
+            f.logger.Debug(
+                "fetched an arrow batch",
+                zap.Duration("duration", time.Since(fetchStart)),
+                zap.Int("row_count", b.GetRowCount()),
+            )
+        }(batch)
     }
+
+    go func() {
+        wg.Wait()
+        close(fetchResultChan)
+    }()
+
+    for result := range fetchResultChan {
+        if result.Err != nil {
+            return nil, result.Err
+        }
-        f.totalRecords += int64(batch.GetRowCount())
-        for _, rec := range *records {
+        batch := result.Batch
+        writeStart := time.Now()
+        for _, rec := range *result.Records {
             if writer.RowGroupTotalBytesWritten() >= rowGroupBufferSize {
                 writer.NewBufferedRowGroup()
             }
@@ -204,6 +258,14 @@ func (f *fileIterator) Next() ([]string, error) {
                 }
             }
         }
+        batchesLeft--
+        f.logger.Debug(
+            "wrote an arrow batch to a parquet file",
+            zap.Float64("progress", float64(len(f.batches)-batchesLeft)/float64(len(f.batches))*100),
+            zap.Int("row_count", batch.GetRowCount()),
+            zap.Duration("write_duration", time.Since(writeStart)),
+        )
+        f.totalRecords += int64(result.Batch.GetRowCount())
     }

     writer.Close()
@@ -237,6 +299,12 @@ func (f *fileIterator) Format() string {

 var _ drivers.FileIterator = &fileIterator{}

+type fetchResult struct {
+    Records *[]arrow.Record
+    Batch   *sf.ArrowBatch
+    Err     error
+}
+
 type sourceProperties struct {
     SQL string `mapstructure:"sql"`
     DSN string `mapstructure:"dsn"`
Reviewer: Is this a standard Snowflake property? If not, I think it would be better to set it as a separate connector property – i.e. configured using --var connector.snowflake.parallel_fetch_limit=N.

Author: No, it isn't a Snowflake property. Moved to an env var.
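The exact environment variable is not shown in this diff, so the sketch below only illustrates the general pattern of reading such a limit from the environment with a fallback to the previous default of 10; the variable name SNOWFLAKE_PARALLEL_FETCH_LIMIT is hypothetical:

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// parallelFetchLimitFromEnv reads a hypothetical SNOWFLAKE_PARALLEL_FETCH_LIMIT
// variable and falls back to a default of 10 when it is unset or invalid.
func parallelFetchLimitFromEnv() int {
	const defaultLimit = 10
	v := os.Getenv("SNOWFLAKE_PARALLEL_FETCH_LIMIT") // hypothetical name
	if v == "" {
		return defaultLimit
	}
	n, err := strconv.Atoi(v)
	if err != nil || n <= 0 {
		return defaultLimit
	}
	return n
}

func main() {
	fmt.Println("parallel fetch limit:", parallelFetchLimitFromEnv())
}
```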