From 5671acfd00ed246ca29fa4d09884728ff8234f89 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Thu, 29 Jun 2023 13:06:11 +0200 Subject: [PATCH] SNOW-833537 Configure retries of JWT token auth --- connection.go | 2 +- dsn.go | 30 ++++- dsn_test.go | 314 ++++++++++++++++++++++++++++++++++---------------- util.go | 22 ++++ 4 files changed, 266 insertions(+), 102 deletions(-) diff --git a/connection.go b/connection.go index 84c2cfe40..5712c3701 100644 --- a/connection.go +++ b/connection.go @@ -73,7 +73,7 @@ type snowflakeConn struct { var ( queryIDPattern = `[\w\-_]+` queryIDRegexp = regexp.MustCompile(queryIDPattern) - errMutex = &sync.Mutex{} + errMutex = &sync.Mutex{} ) func (sc *snowflakeConn) exec( diff --git a/dsn.go b/dsn.go index d2028d9c9..02af2f96f 100644 --- a/dsn.go +++ b/dsn.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "os" + "reflect" "strconv" "strings" "time" @@ -23,10 +24,13 @@ const ( defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout defaultJWTTimeout = 60 * time.Second + defaultJWTMaxRetries = 10 defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login defaultDomain = ".snowflakecomputing.com" ) +var defaultJWTRetryForCodes = [1]int{390144} + // ConfigBool is a type to represent true or false in the Config type ConfigBool uint8 @@ -81,7 +85,9 @@ type Config struct { TokenAccessor TokenAccessor // Optional token accessor to use KeepSessionAlive bool // Enables the session to persist even after the connection is closed - PrivateKey *rsa.PrivateKey // Private key used to sign JWT + PrivateKey *rsa.PrivateKey // Private key used to sign JWT + JWTMaxRetries int // How many retries should be performed before failing authentication + JWTRetryForCodes []int // Which error codes should end with auth retry Transporter http.RoundTripper // RoundTripper to intercept HTTP requests and responses @@ -178,6 +184,12 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.JWTExpireTimeout != defaultJWTTimeout { params.Add("jwtTimeout", strconv.FormatInt(int64(cfg.JWTExpireTimeout/time.Second), 10)) } + if cfg.JWTMaxRetries != defaultJWTMaxRetries { + params.Add("jwtMaxRetries", strconv.Itoa(cfg.JWTMaxRetries)) + } + if !reflect.DeepEqual(cfg.JWTRetryForCodes, defaultJWTRetryForCodes) { + params.Add("jwtRetryForCodes", strings.Join(toStringSlice(cfg.JWTRetryForCodes), ",")) + } if cfg.Application != clientType { params.Add("application", cfg.Application) } @@ -428,6 +440,12 @@ func fillMissingConfigParameters(cfg *Config) error { if cfg.JWTExpireTimeout == 0 { cfg.JWTExpireTimeout = defaultJWTTimeout } + if cfg.JWTMaxRetries == 0 { + cfg.JWTMaxRetries = defaultJWTMaxRetries + } + if len(cfg.JWTRetryForCodes) == 0 { + cfg.JWTRetryForCodes = defaultJWTRetryForCodes[:] + } if cfg.ClientTimeout == 0 { cfg.ClientTimeout = defaultClientTimeout } @@ -578,6 +596,16 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return err } + case "jwtMaxRetries": + cfg.JWTMaxRetries, err = strconv.Atoi(value) + if err != nil { + return err + } + case "jwtRetryForCodes": + cfg.JWTRetryForCodes, err = parseToIntArray(value) + if err != nil { + return err + } case "application": cfg.Application = value case "authenticator": diff --git a/dsn_test.go b/dsn_test.go index 0798f7698..76b529df5 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -38,6 +38,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -51,6 +54,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -62,6 +68,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -74,6 +83,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -87,6 +99,9 @@ func TestParseDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -100,6 +115,9 @@ func TestParseDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -112,6 +130,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -124,6 +145,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -136,6 +160,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -149,6 +176,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -162,6 +192,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -175,6 +208,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: ErrEmptyPassword, @@ -188,6 +224,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: ErrEmptyUsername, @@ -201,6 +240,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: ErrEmptyAccount, @@ -214,6 +256,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -226,6 +271,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -238,6 +286,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -250,21 +301,26 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, }, { - dsn: "u:p@a?database=d&jwtTimeout=20", + dsn: "u:p@a?database=d&jwtTimeout=20&jwtMaxRetries=5&jwtRetryForCodes=1,2,3", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "", JWTExpireTimeout: 20 * time.Second, + JWTMaxRetries: 5, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTRetryForCodes: []int{1, 2, 3}, }, ocspMode: ocspModeFailOpen, }, @@ -274,10 +330,12 @@ func TestParseDSN(t *testing.T) { Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "", - JWTExpireTimeout: defaultJWTTimeout, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, }, @@ -289,6 +347,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: &SnowflakeError{ @@ -307,6 +368,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeInsecure, err: nil, @@ -321,6 +385,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -370,6 +437,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -383,6 +453,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -401,6 +474,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -414,6 +490,9 @@ func TestParseDSN(t *testing.T) { OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: &SnowflakeError{Number: ErrCodePrivateKeyParseError}, @@ -426,6 +505,9 @@ func TestParseDSN(t *testing.T) { Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -438,6 +520,9 @@ func TestParseDSN(t *testing.T) { Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenFalse, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailClosed, err: nil, @@ -450,6 +535,9 @@ func TestParseDSN(t *testing.T) { Database: "db", Schema: "s", OCSPFailOpen: OCSPFailOpenFalse, InsecureMode: true, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeInsecure, err: nil, @@ -460,7 +548,10 @@ func TestParseDSN(t *testing.T) { Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: defaultClientTimeout, + ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -471,7 +562,10 @@ func TestParseDSN(t *testing.T) { Account: "account", User: "user", Password: "pass", Protocol: "https", Host: "account.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolFalse, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: defaultClientTimeout, + ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -482,7 +576,10 @@ func TestParseDSN(t *testing.T) { Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolFalse, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: defaultClientTimeout, + ClientTimeout: defaultClientTimeout, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -493,7 +590,10 @@ func TestParseDSN(t *testing.T) { Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: 300 * time.Second, + ClientTimeout: 300 * time.Second, + JWTExpireTimeout: defaultJWTTimeout, + JWTMaxRetries: defaultJWTMaxRetries, + JWTRetryForCodes: defaultJWTRetryForCodes[:], }, ocspMode: ocspModeFailOpen, err: nil, @@ -501,103 +601,117 @@ func TestParseDSN(t *testing.T) { } for i, test := range testcases { - // t.Logf("Parsing testcase %d, DSN: %s", i, test.dsn) - cfg, err := ParseDSN(test.dsn) - switch { - case test.err == nil: - if err != nil { - t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err) - } - if test.config.Host != cfg.Host { - t.Fatalf("%d: Failed to match host. expected: %v, got: %v", - i, test.config.Host, cfg.Host) - } - if test.config.Account != cfg.Account { - t.Fatalf("%d: Failed to match account. expected: %v, got: %v", - i, test.config.Account, cfg.Account) - } - if test.config.User != cfg.User { - t.Fatalf("%d: Failed to match user. expected: %v, got: %v", - i, test.config.User, cfg.User) - } - if test.config.Password != cfg.Password { - t.Fatalf("%d: Failed to match password. expected: %v, got: %v", - i, test.config.Password, cfg.Password) - } - if test.config.Database != cfg.Database { - t.Fatalf("%d: Failed to match database. expected: %v, got: %v", - i, test.config.Database, cfg.Database) - } - if test.config.Schema != cfg.Schema { - t.Fatalf("%d: Failed to match schema. expected: %v, got: %v", - i, test.config.Schema, cfg.Schema) - } - if test.config.Warehouse != cfg.Warehouse { - t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v", - i, test.config.Warehouse, cfg.Warehouse) - } - if test.config.Role != cfg.Role { - t.Fatalf("%d: Failed to match role. expected: %v, got: %v", - i, test.config.Role, cfg.Role) - } - if test.config.Region != cfg.Region { - t.Fatalf("%d: Failed to match region. expected: %v, got: %v", - i, test.config.Region, cfg.Region) - } - if test.config.Protocol != cfg.Protocol { - t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v", - i, test.config.Protocol, cfg.Protocol) - } - if test.config.Passcode != cfg.Passcode { - t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v", - i, test.config.Passcode, cfg.Passcode) - } - if test.config.PasscodeInPassword != cfg.PasscodeInPassword { - t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v", - i, test.config.PasscodeInPassword, cfg.PasscodeInPassword) - } - if test.config.Authenticator != cfg.Authenticator { - t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v", - i, test.config.Authenticator.String(), cfg.Authenticator.String()) - } - if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL { - t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v", - i, test.config.OktaURL, cfg.OktaURL) - } - if test.config.OCSPFailOpen != cfg.OCSPFailOpen { - t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v", - i, test.config.OCSPFailOpen, cfg.OCSPFailOpen) - } - if test.ocspMode != cfg.ocspMode() { - t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v", - i, test.ocspMode, cfg.ocspMode()) - } - if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters { - t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v", - i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters) - } - if test.config.ClientTimeout != cfg.ClientTimeout { - t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v", - i, test.config.ClientTimeout, cfg.ClientTimeout) - } - case test.err != nil: - driverErrE, okE := test.err.(*SnowflakeError) - driverErrG, okG := err.(*SnowflakeError) - if okE && !okG || !okE && okG { - t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err) - } - if okE && okG { - if driverErrE.Number != driverErrG.Number { - t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number) + t.Run("TestParseDSN", func(t *testing.T) { + // t.Logf("Parsing testcase %d, DSN: %s", i, test.dsn) + cfg, err := ParseDSN(test.dsn) + switch { + case test.err == nil: + if err != nil { + t.Fatalf("%d: Failed to parse the DSN. dsn: %v, err: %v", i, test.dsn, err) + } + if test.config.Host != cfg.Host { + t.Fatalf("%d: Failed to match host. expected: %v, got: %v", + i, test.config.Host, cfg.Host) + } + if test.config.Account != cfg.Account { + t.Fatalf("%d: Failed to match account. expected: %v, got: %v", + i, test.config.Account, cfg.Account) + } + if test.config.User != cfg.User { + t.Fatalf("%d: Failed to match user. expected: %v, got: %v", + i, test.config.User, cfg.User) + } + if test.config.Password != cfg.Password { + t.Fatalf("%d: Failed to match password. expected: %v, got: %v", + i, test.config.Password, cfg.Password) + } + if test.config.Database != cfg.Database { + t.Fatalf("%d: Failed to match database. expected: %v, got: %v", + i, test.config.Database, cfg.Database) } - } else { - t1 := reflect.TypeOf(err) - t2 := reflect.TypeOf(test.err) - if t1 != t2 { - t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err) + if test.config.Schema != cfg.Schema { + t.Fatalf("%d: Failed to match schema. expected: %v, got: %v", + i, test.config.Schema, cfg.Schema) + } + if test.config.Warehouse != cfg.Warehouse { + t.Fatalf("%d: Failed to match warehouse. expected: %v, got: %v", + i, test.config.Warehouse, cfg.Warehouse) + } + if test.config.Role != cfg.Role { + t.Fatalf("%d: Failed to match role. expected: %v, got: %v", + i, test.config.Role, cfg.Role) + } + if test.config.Region != cfg.Region { + t.Fatalf("%d: Failed to match region. expected: %v, got: %v", + i, test.config.Region, cfg.Region) + } + if test.config.Protocol != cfg.Protocol { + t.Fatalf("%d: Failed to match protocol. expected: %v, got: %v", + i, test.config.Protocol, cfg.Protocol) + } + if test.config.Passcode != cfg.Passcode { + t.Fatalf("%d: Failed to match passcode. expected: %v, got: %v", + i, test.config.Passcode, cfg.Passcode) + } + if test.config.PasscodeInPassword != cfg.PasscodeInPassword { + t.Fatalf("%d: Failed to match passcodeInPassword. expected: %v, got: %v", + i, test.config.PasscodeInPassword, cfg.PasscodeInPassword) + } + if test.config.Authenticator != cfg.Authenticator { + t.Fatalf("%d: Failed to match Authenticator. expected: %v, got: %v", + i, test.config.Authenticator.String(), cfg.Authenticator.String()) + } + if test.config.Authenticator == AuthTypeOkta && *test.config.OktaURL != *cfg.OktaURL { + t.Fatalf("%d: Failed to match okta URL. expected: %v, got: %v", + i, test.config.OktaURL, cfg.OktaURL) + } + if test.config.OCSPFailOpen != cfg.OCSPFailOpen { + t.Fatalf("%d: Failed to match OCSPFailOpen. expected: %v, got: %v", + i, test.config.OCSPFailOpen, cfg.OCSPFailOpen) + } + if test.ocspMode != cfg.ocspMode() { + t.Fatalf("%d: Failed to match OCSPMode. expected: %v, got: %v", + i, test.ocspMode, cfg.ocspMode()) + } + if test.config.ValidateDefaultParameters != cfg.ValidateDefaultParameters { + t.Fatalf("%d: Failed to match ValidateDefaultParameters. expected: %v, got: %v", + i, test.config.ValidateDefaultParameters, cfg.ValidateDefaultParameters) + } + if test.config.ClientTimeout != cfg.ClientTimeout { + t.Fatalf("%d: Failed to match ClientTimeout. expected: %v, got: %v", + i, test.config.ClientTimeout, cfg.ClientTimeout) + } + if test.config.JWTExpireTimeout != cfg.JWTExpireTimeout { + t.Fatalf("%d: Failed to match JWTExpireTimeout. epxected: %v, got: %v", + i, test.config.JWTExpireTimeout, cfg.JWTExpireTimeout) + } + if test.config.JWTMaxRetries != cfg.JWTMaxRetries { + t.Fatalf("%d: Failed to match JWTMaxRetries. expected: %v, got: %v", + i, test.config.JWTMaxRetries, cfg.JWTMaxRetries) + } + if !reflect.DeepEqual(test.config.JWTRetryForCodes, cfg.JWTRetryForCodes) { + t.Fatalf("%d: Failed to match JWTRetryCodes. expected: %v, got: %v", + i, test.config.JWTRetryForCodes, cfg.JWTRetryForCodes) + } + case test.err != nil: + driverErrE, okE := test.err.(*SnowflakeError) + driverErrG, okG := err.(*SnowflakeError) + if okE && !okG || !okE && okG { + t.Fatalf("%d: Wrong error. expected: %v, got: %v", i, test.err, err) + } + if okE && okG { + if driverErrE.Number != driverErrG.Number { + t.Fatalf("%d: Wrong error number. expected: %v, got: %v", i, driverErrE.Number, driverErrG.Number) + } + } else { + t1 := reflect.TypeOf(err) + t2 := reflect.TypeOf(test.err) + if t1 != t2 { + t.Fatalf("%d: Wrong error. expected: %T:%v, got: %T:%v", i, test.err, test.err, err, err) + } } } - } + }) } } diff --git a/util.go b/util.go index d112189ff..5ab60ad2d 100644 --- a/util.go +++ b/util.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "os" + "strconv" "strings" "sync" "time" @@ -235,3 +236,24 @@ func GetFromEnv(name string, failOnMissing bool) (string, error) { } return "", nil } + +func toStringSlice(arr []int) []string { + strArr := make([]string, len(arr)) + for i, val := range arr { + strArr[i] = strconv.Itoa(val) + } + return strArr +} + +func parseToIntArray(value string) ([]int, error) { + codesAsStr := strings.Split(value, ",") + codes := make([]int, len(codesAsStr)) + for i, valAsStr := range codesAsStr { + val, err := strconv.Atoi(valAsStr) + if err != nil { + return nil, err + } + codes[i] = val + } + return codes, nil +}