Skip to content

Commit

Permalink
SNOW-833537 Configure retries of JWT token auth
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Jun 29, 2023
1 parent 9cbda08 commit eaafe5f
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 119 deletions.
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ type snowflakeConn struct {
var (
queryIDPattern = `[\w\-_]+`
queryIDRegexp = regexp.MustCompile(queryIDPattern)
errMutex = &sync.Mutex{}
errMutex = &sync.Mutex{}
)

func (sc *snowflakeConn) exec(
Expand Down
30 changes: 29 additions & 1 deletion dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"time"
Expand All @@ -23,10 +24,13 @@ const (
defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout
defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout
defaultJWTTimeout = 60 * time.Second
defaultJWTMaxRetries = 10
defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login
defaultDomain = ".snowflakecomputing.com"
)

var defaultJWTRetryForCodes = [1]int{390144}

// ConfigBool is a type to represent true or false in the Config
type ConfigBool uint8

Expand Down Expand Up @@ -81,7 +85,9 @@ type Config struct {
TokenAccessor TokenAccessor // Optional token accessor to use
KeepSessionAlive bool // Enables the session to persist even after the connection is closed

PrivateKey *rsa.PrivateKey // Private key used to sign JWT
PrivateKey *rsa.PrivateKey // Private key used to sign JWT
JWTMaxRetries int // How many retries should be performed before failing authentication
JWTRetryForCodes []int // Which error codes should end with auth retry

Transporter http.RoundTripper // RoundTripper to intercept HTTP requests and responses

Expand Down Expand Up @@ -178,6 +184,12 @@ func DSN(cfg *Config) (dsn string, err error) {
if cfg.JWTExpireTimeout != defaultJWTTimeout {
params.Add("jwtTimeout", strconv.FormatInt(int64(cfg.JWTExpireTimeout/time.Second), 10))
}
if cfg.JWTMaxRetries != defaultJWTMaxRetries {
params.Add("jwtMaxRetries", strconv.Itoa(cfg.JWTMaxRetries))
}
if !reflect.DeepEqual(cfg.JWTRetryForCodes, defaultJWTRetryForCodes[:]) {
params.Add("jwtRetryForCodes", strings.Join(toStringSlice(cfg.JWTRetryForCodes), ","))
}
if cfg.Application != clientType {
params.Add("application", cfg.Application)
}
Expand Down Expand Up @@ -428,6 +440,12 @@ func fillMissingConfigParameters(cfg *Config) error {
if cfg.JWTExpireTimeout == 0 {
cfg.JWTExpireTimeout = defaultJWTTimeout
}
if cfg.JWTMaxRetries == 0 {
cfg.JWTMaxRetries = defaultJWTMaxRetries
}
if len(cfg.JWTRetryForCodes) == 0 {
cfg.JWTRetryForCodes = defaultJWTRetryForCodes[:]
}
if cfg.ClientTimeout == 0 {
cfg.ClientTimeout = defaultClientTimeout
}
Expand Down Expand Up @@ -578,6 +596,16 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return err
}
case "jwtMaxRetries":
cfg.JWTMaxRetries, err = strconv.Atoi(value)
if err != nil {
return err
}
case "jwtRetryForCodes":
cfg.JWTRetryForCodes, err = parseToIntArray(value)
if err != nil {
return err
}
case "application":
cfg.Application = value
case "authenticator":
Expand Down
Loading

0 comments on commit eaafe5f

Please sign in to comment.