diff --git a/async.go b/async.go index 5c477dc1a..363792f30 100644 --- a/async.go +++ b/async.go @@ -63,25 +63,45 @@ func (sr *snowflakeRestful) getAsync( defer close(errChannel) token, _, _ := sr.TokenAccessor.GetTokens() headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) - resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout) - if err != nil { - logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) - sfError.Message = err.Error() - errChannel <- sfError - return err - } - if resp.Body != nil { + + var err error + var respd execResponse + retry := 0 + retryPattern := []int32{1, 1, 2, 3, 4, 8, 10} + + for { + resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout) + if err != nil { + logger.WithContext(ctx).Errorf("failed to get response. err: %v", err) + sfError.Message = err.Error() + errChannel <- sfError + return err + } defer resp.Body.Close() - } - respd := execResponse{} - err = json.NewDecoder(resp.Body).Decode(&respd) - resp.Body.Close() - if err != nil { - logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) - sfError.Message = err.Error() - errChannel <- sfError - return err + respd = execResponse{} // reset the response + err = json.NewDecoder(resp.Body).Decode(&respd) + if err != nil { + logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err) + sfError.Message = err.Error() + errChannel <- sfError + return err + } + if respd.Code != queryInProgressAsyncCode { + // If the query takes longer than 45 seconds to complete the results are not returned. + // If the query is still in progress after 45 seconds, retry the request to the /results endpoint. + // For all other scenarios continue processing results response + break + } else { + // Sleep before retrying get result request. Exponential backoff up to 5 seconds. + sleepTime := time.Millisecond * time.Duration(500*retryPattern[retry]) + logger.WithContext(ctx).Infof("Query execution still in progress. Sleep for %v ms", sleepTime) + time.Sleep(sleepTime) + } + if retry < len(retryPattern)-1 { + retry++ + } + } sc := &snowflakeConn{rest: sr, cfg: cfg} diff --git a/async_test.go b/async_test.go index b2eae18cf..d69619c8a 100644 --- a/async_test.go +++ b/async_test.go @@ -136,3 +136,38 @@ func retrieveRows(rows *RowsExtended, ch chan string) { ch <- s close(ch) } + +func TestLongRunningAsyncQuery(t *testing.T) { + db := openDB(t) + defer db.Close() + + ctx, _ := WithMultiStatement(context.Background(), 0) + query := "CALL SYSTEM$WAIT(50, 'SECONDS');use snowflake_sample_data" + + rows, err := db.QueryContext(WithAsyncMode(ctx), query) + if err != nil { + t.Fatalf("failed to run a query. %v, err: %v", query, err) + } + defer rows.Close() + var v string + i := 0 + for { + for rows.Next() { + err := rows.Scan(&v) + if err != nil { + t.Fatalf("failed to get result. err: %v", err) + } + if v == "" { + t.Fatal("should have returned a result") + } + results := []string{"waited 50 seconds", "Statement executed successfully."} + if v != results[i] { + t.Fatalf("unexpected result returned. expected: %v, but got: %v", results[i], v) + } + i++ + } + if !rows.NextResultSet() { + break + } + } +}