diff --git a/go.mod b/go.mod
index 2b4c02e28ae..2f0dcab62f0 100644
--- a/go.mod
+++ b/go.mod
@@ -69,7 +69,7 @@ require (
 	github.com/robfig/cron/v3 v3.0.1
 	github.com/rs/cors v1.9.0
 	github.com/sashabaranov/go-openai v1.19.3
-	github.com/snowflakedb/gosnowflake v1.7.2
+	github.com/snowflakedb/gosnowflake v1.8.0
 	github.com/spf13/cobra v1.8.0
 	github.com/spf13/pflag v1.0.5
 	github.com/stretchr/testify v1.8.4
@@ -316,4 +316,10 @@ replace github.com/marcboeker/go-duckdb v1.5.5 => github.com/rilldata/go-duckdb
 // Fixes a security warning. Remove when testcontainers-go v0.27.0 is released.
 replace github.com/testcontainers/testcontainers-go v0.26.0 => github.com/testcontainers/testcontainers-go v0.26.1-0.20231102155908-6aac7412c81a
 
+// gosnowflake v1.8.0 has an issue with arrow batches: it returns 0 batches if the first batch has no records.
+// See the corresponding PR for details: https://github.com/snowflakedb/gosnowflake/pull/1068
+// The issue should be fixed in v1.8.1, but verify that apache/arrow/go/v15 doesn't introduce any breaking changes.
+// See the following PR: https://github.com/snowflakedb/gosnowflake/pull/1062
+replace github.com/snowflakedb/gosnowflake v1.8.0 => github.com/snowflakedb/gosnowflake v1.8.1-0.20240311092318-48c5e93a4d51
+
 exclude modernc.org/sqlite v1.18.1
diff --git a/go.sum b/go.sum
index 9039e6854e7..607caf9db4b 100644
--- a/go.sum
+++ b/go.sum
@@ -1998,8 +1998,8 @@ github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:s
 github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
 github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
 github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM=
-github.com/snowflakedb/gosnowflake v1.7.2 h1:HRSwva8YXC64WUppfmHcMNVVzSE1+EwXXaJxgS0EkTo=
-github.com/snowflakedb/gosnowflake v1.7.2/go.mod h1:03tW856vc3ceM4rJuj7KO4dzqN7qoezTm+xw7aPIIFo=
+github.com/snowflakedb/gosnowflake v1.8.1-0.20240311092318-48c5e93a4d51 h1:ARZUHyxhujUe4cPUzGGo0p4nNHVlClPUC9u/4EhkMuE=
+github.com/snowflakedb/gosnowflake v1.8.1-0.20240311092318-48c5e93a4d51/go.mod h1:7yyY2MxtDti2eXgtvlZ8QxzCN6KV2B4qb1HuygMI+0U=
 github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
 github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0=
 github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
diff --git a/runtime/drivers/snowflake/sql_store.go b/runtime/drivers/snowflake/sql_store.go
index 60ccb77fe84..3c010f97e7f 100644
--- a/runtime/drivers/snowflake/sql_store.go
+++ b/runtime/drivers/snowflake/sql_store.go
@@ -8,6 +8,7 @@ import (
 	"io"
 	"os"
 	"strconv"
+	"sync"
 	"time"
 
 	"github.com/apache/arrow/go/v14/arrow"
@@ -201,76 +202,55 @@ func (f *fileIterator) Next() ([]string, error) {
 		zap.Int("batches", len(f.batches)),
 		zap.Int("parallel_fetch_limit", f.parallelFetchLimit))
 
 	// Fetch batches async
-	fetchGrp, ctx := errgroup.WithContext(f.ctx)
-	fetchGrp.SetLimit(f.parallelFetchLimit)
-	fetchResultChan := make(chan fetchResult)
-
-	// Write batches into a file async
-	writeGrp, _ := errgroup.WithContext(f.ctx)
-	writeGrp.Go(func() error {
-		batchesLeft := len(f.batches)
-		for {
-			select {
-			case result, ok := <-fetchResultChan:
-				if !ok {
-					return nil
-				}
-				batch := result.batch
-				writeStart := time.Now()
-				for _, rec := range *result.records {
-					if writer.RowGroupTotalBytesWritten() >= rowGroupBufferSize {
-						writer.NewBufferedRowGroup()
-					}
-					if err := writer.WriteBuffered(rec); err != nil {
-						return err
-					}
-					fileInfo, err := os.Stat(fw.Name())
-					if err == nil { // ignore error
-						if fileInfo.Size() > f.limitInBytes {
-							return drivers.ErrStorageLimitExceeded
-						}
-					}
-				}
-				batchesLeft--
-				f.logger.Debug(
-					"wrote an arrow batch to a parquet file",
-					zap.Float64("progress", float64(len(f.batches)-batchesLeft)/float64(len(f.batches))*100),
-					zap.Int("row_count", batch.GetRowCount()),
-					zap.Duration("write_duration", time.Since(writeStart)),
-				)
-				f.totalRecords += int64(result.batch.GetRowCount())
-			case <-ctx.Done():
-				return ctx.Err()
-			}
-		}
-	})
+	errGrp, _ := errgroup.WithContext(f.ctx)
+	errGrp.SetLimit(f.parallelFetchLimit)
+	// mutex to protect file writes
+	var mu sync.Mutex
+	batchesLeft := len(f.batches)
 	for _, batch := range f.batches {
 		b := batch
-		fetchGrp.Go(func() error {
+		errGrp.Go(func() error {
 			fetchStart := time.Now()
 			records, err := b.Fetch()
 			if err != nil {
 				return err
 			}
-			fetchResultChan <- fetchResult{records: records, batch: b}
 			f.logger.Debug(
 				"fetched an arrow batch",
 				zap.Duration("duration", time.Since(fetchStart)),
 				zap.Int("row_count", b.GetRowCount()),
 			)
+			mu.Lock()
+			defer mu.Unlock()
+			writeStart := time.Now()
+			for _, rec := range *records {
+				if writer.RowGroupTotalBytesWritten() >= rowGroupBufferSize {
+					writer.NewBufferedRowGroup()
+				}
+				if err := writer.WriteBuffered(rec); err != nil {
+					return err
+				}
+				fileInfo, err := os.Stat(fw.Name())
+				if err == nil { // ignore error
+					if fileInfo.Size() > f.limitInBytes {
+						return drivers.ErrStorageLimitExceeded
+					}
+				}
+			}
+			batchesLeft--
+			f.logger.Debug(
+				"wrote an arrow batch to a parquet file",
+				zap.Float64("progress", float64(len(f.batches)-batchesLeft)/float64(len(f.batches))*100),
+				zap.Int("row_count", b.GetRowCount()),
+				zap.Duration("write_duration", time.Since(writeStart)),
+			)
+			f.totalRecords += int64(b.GetRowCount())
 			return nil
 		})
 	}
 
-	err = fetchGrp.Wait()
-	close(fetchResultChan)
-
-	if err != nil {
-		return nil, err
-	}
-
-	if err := writeGrp.Wait(); err != nil {
+	if err := errGrp.Wait(); err != nil {
 		return nil, err
 	}
 
@@ -305,11 +285,6 @@ func (f *fileIterator) Format() string {
 
 var _ drivers.FileIterator = &fileIterator{}
 
-type fetchResult struct {
-	records *[]arrow.Record
-	batch   *sf.ArrowBatch
-}
-
 type sourceProperties struct {
 	SQL string `mapstructure:"sql"`
 	DSN string `mapstructure:"dsn"`
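Note on the refactor above: the old implementation ran two errgroups, with fetch goroutines handing results to a dedicated writer goroutine over the unbuffered fetchResultChan; because the send did not select on the group's context, an early writer exit (e.g. drivers.ErrStorageLimitExceeded) could leave senders blocked with no receiver. The new version uses a single errgroup in which each worker fetches its batch in parallel (bounded by parallelFetchLimit) and then writes under a sync.Mutex. Below is a minimal standalone sketch of that pattern; fetchBatch and writeRecord are hypothetical stand-ins for sf.ArrowBatch.Fetch and the parquet writer's WriteBuffered, not the driver's actual API.

// Minimal sketch (not the driver's code): fetch concurrently, write serially.
package main

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

// fetchBatch stands in for the network-bound b.Fetch() call; it may run in parallel.
func fetchBatch(ctx context.Context, id int) ([]string, error) {
	return []string{fmt.Sprintf("rec-%d-a", id), fmt.Sprintf("rec-%d-b", id)}, nil
}

// writeRecord stands in for writer.WriteBuffered(rec); it must not run concurrently.
func writeRecord(rec string) error {
	fmt.Println("wrote", rec)
	return nil
}

func main() {
	grp, ctx := errgroup.WithContext(context.Background())
	grp.SetLimit(2) // cap concurrent fetches, like parallelFetchLimit

	var mu sync.Mutex // serializes access to the single shared writer, like mu in the diff
	for i := 0; i < 4; i++ {
		id := i // capture the loop variable (required before Go 1.22)
		grp.Go(func() error {
			recs, err := fetchBatch(ctx, id) // parallel section
			if err != nil {
				return err // first error cancels the group's context
			}
			mu.Lock() // serialized section
			defer mu.Unlock()
			for _, rec := range recs {
				if err := writeRecord(rec); err != nil {
					return err
				}
			}
			return nil
		})
	}
	if err := grp.Wait(); err != nil {
		fmt.Println("ingest failed:", err)
	}
}

The trade-off of this design is that a fetched batch is held in memory while its goroutine waits for the lock, but writes stay strictly serialized, there is no channel handoff to deadlock, and a single Wait surfaces the first error from either fetching or writing.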