diff --git a/connection_configuration.go b/connection_configuration.go index 491e2f9fd..86474adda 100644 --- a/connection_configuration.go +++ b/connection_configuration.go @@ -14,6 +14,11 @@ import ( toml "github.com/BurntSushi/toml" ) +const ( + connectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME" + home = "SNOWFLAKE_HOME" +) + // LoadConnectionConfig returns connection configs loaded from the toml file. // By default, SNOWFLAKE_HOME(toml file path) is os.home/snowflake // and SNOWFLAKE_DEFAULT_CONNECTION_NAME(DSN) is 'default' @@ -22,8 +27,8 @@ func LoadConnectionConfig() (*Config, error) { Params: make(map[string]*string), Authenticator: AuthTypeSnowflake, // Default to snowflake } - dsn := getConnectionDSN(os.Getenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME")) - snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) + dsn := getConnectionDSN(os.Getenv(connectionName)) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv(home)) if err != nil { return nil, err } @@ -56,276 +61,145 @@ func LoadConnectionConfig() (*Config, error) { } func parseToml(cfg *Config, connection map[string]interface{}) error { - var v, tokenPath string - var parsingErr error - var vv bool - err := &SnowflakeError{ - Number: ErrCodeTomlFileParsingFailed, - Message: errMsgFailedToParseTomlFile, - } for key, value := range connection { - switch strings.ToLower(key) { - case "user", "username": - if cfg.User, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "password": - if cfg.Password, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "host": - if cfg.Host, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "account": - if cfg.Account, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "warehouse": - if cfg.Warehouse, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "database": - if cfg.Database, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "schema": - if cfg.Schema, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "role": - if cfg.Role, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "region": - if cfg.Region, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "protocol": - if cfg.Protocol, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "passcode": - if cfg.Passcode, parsingErr = parseString(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "port": - if cfg.Port, parsingErr = parseInt(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "passcodeinpassword": - if cfg.PasscodeInPassword, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "clienttimeout": - if cfg.ClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "jwtclienttimeout": - if cfg.JWTClientTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "logintimeout": - if cfg.LoginTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "requesttimeout": - if cfg.RequestTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "jwttimeout": - if cfg.JWTExpireTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "externalbrowsertimeout": - if cfg.ExternalBrowserTimeout, parsingErr = parseDuration(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "maxretrycount": - if cfg.MaxRetryCount, parsingErr = parseInt(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "application": - cfg.Application, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "authenticator": - v, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - parsingErr = determineAuthenticatorType(cfg, v) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "insecuremode": - if cfg.InsecureMode, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "ocspfailopen": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.OCSPFailOpen = OCSPFailOpenTrue - } else { - cfg.OCSPFailOpen = OCSPFailOpenFalse - } - - case "token": - cfg.Token, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "privatekey": - v, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - block, decodeErr := base64.URLEncoding.DecodeString(v) - if decodeErr != nil { - return &SnowflakeError{ - Number: ErrCodePrivateKeyParseError, - Message: "Base64 decode failed", - } - } - cfg.PrivateKey, parsingErr = parsePKCS8PrivateKey(block) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "validatedefaultparameters": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.ValidateDefaultParameters = ConfigBoolTrue - } else { - cfg.ValidateDefaultParameters = ConfigBoolFalse - } - case "clientrequestmfatoken": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.ClientRequestMfaToken = ConfigBoolTrue - } else { - cfg.ClientRequestMfaToken = ConfigBoolFalse - } - case "clientstoretemporarycredential": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.ClientStoreTemporaryCredential = ConfigBoolTrue - } else { - cfg.ClientStoreTemporaryCredential = ConfigBoolFalse - } - case "tracing": - cfg.Tracing, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "tmpdirpath": - cfg.TmpDirPath, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "disablequerycontextcache": - if cfg.DisableQueryContextCache, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "includeretryreason": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.IncludeRetryReason = ConfigBoolTrue - } else { - cfg.IncludeRetryReason = ConfigBoolFalse - } - case "clientconfigfile": - cfg.ClientConfigFile, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - case "disableconsolelogin": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.DisableConsoleLogin = ConfigBoolTrue - } else { - cfg.DisableConsoleLogin = ConfigBoolFalse - } - case "disablesamlurlcheck": - if vv, parsingErr = parseBool(value); parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - if vv { - cfg.DisableSamlURLCheck = ConfigBoolTrue - } else { - cfg.DisableSamlURLCheck = ConfigBoolFalse - } - case "token_file_path": - tokenPath, parsingErr = parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - default: - param, parsingErr := parseString(value) - if parsingErr != nil { - err.MessageArgs = []interface{}{key, value} - return err - } - cfg.Params[urlDecodeIfNeeded(key)] = ¶m + if err := handleSingleParam(cfg, key, value); err != nil { + return err } } if shouldReadTokenFromFile(cfg) { + v, err := readToken("") + if err != nil { + return err + } + cfg.Token = v + } + return nil +} + +func handleSingleParam(cfg *Config, key string, value interface{}) error { + var parsingErr error + var v, tokenPath string + switch strings.ToLower(key) { + case "user", "username": + cfg.User, parsingErr = parseString(value) + case "password": + cfg.Password, parsingErr = parseString(value) + case "host": + cfg.Host, parsingErr = parseString(value) + case "account": + cfg.Account, parsingErr = parseString(value) + case "warehouse": + cfg.Warehouse, parsingErr = parseString(value) + case "database": + cfg.Database, parsingErr = parseString(value) + case "schema": + cfg.Schema, parsingErr = parseString(value) + case "role": + cfg.Role, parsingErr = parseString(value) + case "region": + cfg.Region, parsingErr = parseString(value) + case "protocol": + cfg.Protocol, parsingErr = parseString(value) + case "passcode": + cfg.Passcode, parsingErr = parseString(value) + case "port": + cfg.Port, parsingErr = parseInt(value) + case "passcodeinpassword": + cfg.PasscodeInPassword, parsingErr = parseBool(value) + case "clienttimeout": + cfg.ClientTimeout, parsingErr = parseDuration(value) + case "jwtclienttimeout": + cfg.JWTClientTimeout, parsingErr = parseDuration(value) + case "logintimeout": + cfg.LoginTimeout, parsingErr = parseDuration(value) + case "requesttimeout": + cfg.RequestTimeout, parsingErr = parseDuration(value) + case "jwttimeout": + cfg.JWTExpireTimeout, parsingErr = parseDuration(value) + case "externalbrowsertimeout": + cfg.ExternalBrowserTimeout, parsingErr = parseDuration(value) + case "maxretrycount": + cfg.MaxRetryCount, parsingErr = parseInt(value) + case "application": + cfg.Application, parsingErr = parseString(value) + case "authenticator": + v, parsingErr = parseString(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } + parsingErr = determineAuthenticatorType(cfg, v) + case "insecuremode": + cfg.InsecureMode, parsingErr = parseBool(value) + case "ocspfailopen": + var vv ConfigBool + vv, parsingErr = parseConfigBool(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } + cfg.OCSPFailOpen = OCSPFailOpenMode(vv) + case "token": + cfg.Token, parsingErr = parseString(value) + case "privatekey": + v, parsingErr = parseString(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } + block, decodeErr := base64.URLEncoding.DecodeString(v) + if decodeErr != nil { + return &SnowflakeError{ + Number: ErrCodePrivateKeyParseError, + Message: "Base64 decode failed", + } + } + cfg.PrivateKey, parsingErr = parsePKCS8PrivateKey(block) + case "validatedefaultparameters": + cfg.ValidateDefaultParameters, parsingErr = parseConfigBool(value) + case "clientrequestmfatoken": + cfg.ClientRequestMfaToken, parsingErr = parseConfigBool(value) + case "clientstoretemporarycredential": + cfg.ClientStoreTemporaryCredential, parsingErr = parseConfigBool(value) + case "tracing": + cfg.Tracing, parsingErr = parseString(value) + case "tmpdirpath": + cfg.TmpDirPath, parsingErr = parseString(value) + case "disablequerycontextcache": + cfg.DisableQueryContextCache, parsingErr = parseBool(value) + case "includeretryreason": + cfg.IncludeRetryReason, parsingErr = parseConfigBool(value) + case "clientconfigfile": + cfg.ClientConfigFile, parsingErr = parseString(value) + case "disableconsolelogin": + cfg.DisableConsoleLogin, parsingErr = parseConfigBool(value) + case "disablesamlurlcheck": + cfg.DisableSamlURLCheck, parsingErr = parseConfigBool(value) + case "token_file_path": + tokenPath, parsingErr = parseString(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } v, err := readToken(tokenPath) if err != nil { return err } cfg.Token = v + default: + param, parsingErr := parseString(value) + if err := checkParsingError(parsingErr, key, value); err != nil { + return err + } + cfg.Params[urlDecodeIfNeeded(key)] = ¶m + } + return checkParsingError(parsingErr, key, value) +} + +func checkParsingError(parsingErr error, key string, value interface{}) error { + if parsingErr != nil { + err := &SnowflakeError{ + Number: ErrCodeTomlFileParsingFailed, + Message: errMsgFailedToParseTomlFile, + MessageArgs: []interface{}{key, value}, + } + return err } return nil } @@ -363,6 +237,17 @@ func parseBool(i interface{}) (bool, error) { return vv, nil } +func parseConfigBool(i interface{}) (ConfigBool, error) { + vv, err := parseBool(i) + if err != nil { + return ConfigBoolFalse, err + } + if vv { + return ConfigBoolTrue, nil + } + return ConfigBoolFalse, nil +} + func parseDuration(i interface{}) (time.Duration, error) { v, ok := i.(string) if !ok { @@ -373,11 +258,7 @@ func parseDuration(i interface{}) (time.Duration, error) { t := int64(num) return time.Duration(t * int64(time.Second)), nil } - t, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return time.Duration(0), err - } - return time.Duration(t * int64(time.Second)), nil + return parseTimeout(v) } func readToken(tokenPath string) (string, error) { @@ -385,7 +266,7 @@ func readToken(tokenPath string) (string, error) { tokenPath = "./snowflake/session/token" } if !path.IsAbs(tokenPath) { - snowflakeConfigDir, err := getTomlFilePath(os.Getenv("SNOWFLAKE_HOME")) + snowflakeConfigDir, err := getTomlFilePath(os.Getenv(home)) if err != nil { return "", err } @@ -411,11 +292,7 @@ func parseString(i interface{}) (string, error) { } func getTomlFilePath(filePath string) (string, error) { - if len(filePath) != 0 { - if path.IsAbs(filePath) { - return filePath, nil - } - } else { + if len(filePath) == 0 { homeDir, err := os.UserHomeDir() if err != nil { return "", err diff --git a/connection_configuration_test.go b/connection_configuration_test.go index fc1acfad4..f68f42b08 100644 --- a/connection_configuration_test.go +++ b/connection_configuration_test.go @@ -258,14 +258,14 @@ func TestGetTomlFilePath(t *testing.T) { assertNilF(t, err, "should not have failed") assertEqualF(t, dir, result) - result = "/user/somelocation/b" - if isWindows { - result = "c:\\user\\somelocation\\b" + //Absolute path for windows can be varied depend on which disk the driver is located. + // As a result, this test is available on non-Window machines. + if !isWindows { + result = "/user/somelocation/b" + location = "/user//somelocation///b" + dir, err = getTomlFilePath(location) + assertNilF(t, err, "should not have failed") + assertEqualF(t, dir, result) } - location = "/user//somelocation///b" - dir, err = getTomlFilePath(location) - assertNilF(t, err, "should not have failed") - // result, err = path.Abs(location) - assertNilF(t, err, "should not have failed") - assertEqualF(t, dir, result) + }