From 78e3a6abfb4c41b672bbc9812a1837ad48374fc6 Mon Sep 17 00:00:00 2001 From: skShekhar Date: Thu, 28 Nov 2024 14:45:41 +0530 Subject: [PATCH 1/7] feat(snowflake): support oauth authentication --- sqlconnect/internal/snowflake/config.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sqlconnect/internal/snowflake/config.go b/sqlconnect/internal/snowflake/config.go index 4485687..095f92c 100644 --- a/sqlconnect/internal/snowflake/config.go +++ b/sqlconnect/internal/snowflake/config.go @@ -34,6 +34,9 @@ type Config struct { KeepSessionAlive bool `json:"keepSessionAlive"` UseLegacyMappings bool `json:"useLegacyMappings"` QueryTag string `json:"queryTag"` + Host string `json:"host"` + UseOAuth bool `json:"use_oauth"` + OAuthToken string `json:"oauth_token"` } func (c Config) ConnectionString() (dsn string, err error) { @@ -58,6 +61,10 @@ func (c Config) ConnectionString() (dsn string, err error) { return "", fmt.Errorf("parsing private key: %w", err) } sc.PrivateKey = privateKey + } else if c.UseOAuth { + sc.Authenticator = gosnowflake.AuthTypeOAuth + sc.Host = c.Host + sc.Token = c.OAuthToken } if c.KeepSessionAlive { From 48b0a0239b63532e7ad85f7db77e51ee23bfdce4 Mon Sep 17 00:00:00 2001 From: Arnab Pal Date: Tue, 17 Dec 2024 00:06:25 +0530 Subject: [PATCH 2/7] Updated snowflake config for OAuth --- sqlconnect/internal/snowflake/config.go | 90 +++++++++++++++---------- 1 file changed, 56 insertions(+), 34 deletions(-) diff --git a/sqlconnect/internal/snowflake/config.go b/sqlconnect/internal/snowflake/config.go index 095f92c..817ca8d 100644 --- a/sqlconnect/internal/snowflake/config.go +++ b/sqlconnect/internal/snowflake/config.go @@ -20,6 +20,7 @@ type Config struct { User string `json:"user"` Schema string `json:"schema"` Role string `json:"role"` + Region string `json:"region"` Password string `json:"password"` @@ -40,45 +41,66 @@ type Config struct { } func (c Config) ConnectionString() (dsn string, err error) { - sc := gosnowflake.Config{ - Authenticator: gosnowflake.AuthTypeSnowflake, - User: c.User, - Password: c.Password, - Account: c.Account, - Database: c.DBName, - Warehouse: c.Warehouse, - Schema: c.Schema, - Role: c.Role, - Application: c.Application, - LoginTimeout: c.LoginTimeout, - Params: make(map[string]*string), - } - - if c.UseKeyPairAuth { - sc.Authenticator = gosnowflake.AuthTypeJwt - privateKey, err := c.ParsePrivateKey() + if c.UseOAuth { + sc := gosnowflake.Config{ + Authenticator: gosnowflake.AuthTypeOAuth, + Account: c.Account, + Region: c.Region, + Token: c.OAuthToken, + Warehouse: c.Warehouse, + Schema: c.Schema, + Database: c.DBName, + Host: c.Host, + Protocol: "https", + Port: 443, + KeepSessionAlive: true, + } + dsn, err = gosnowflake.DSN(&sc) if err != nil { - return "", fmt.Errorf("parsing private key: %w", err) + err = fmt.Errorf("creating dsn: %v", err) + } + } else { + sc := gosnowflake.Config{ + Authenticator: gosnowflake.AuthTypeSnowflake, + User: c.User, + Password: c.Password, + Account: c.Account, + Database: c.DBName, + Warehouse: c.Warehouse, + Schema: c.Schema, + Role: c.Role, + Application: c.Application, + LoginTimeout: c.LoginTimeout, + Params: make(map[string]*string), } - sc.PrivateKey = privateKey - } else if c.UseOAuth { - sc.Authenticator = gosnowflake.AuthTypeOAuth - sc.Host = c.Host - sc.Token = c.OAuthToken - } - if c.KeepSessionAlive { - valueTrue := "true" - sc.Params["client_session_keep_alive"] = &valueTrue - } + if c.UseKeyPairAuth { + sc.Authenticator = gosnowflake.AuthTypeJwt + privateKey, err := c.ParsePrivateKey() + if err != nil { + return "", fmt.Errorf("parsing private key: %w", err) + } + sc.PrivateKey = privateKey + } else if c.UseOAuth { + sc.Authenticator = gosnowflake.AuthTypeOAuth + sc.Host = c.Host + sc.Token = c.OAuthToken + sc.User = c.User + } - if c.QueryTag != "" { - sc.Params["query_tag"] = &c.QueryTag - } + if c.KeepSessionAlive { + valueTrue := "true" + sc.Params["client_session_keep_alive"] = &valueTrue + } - dsn, err = gosnowflake.DSN(&sc) - if err != nil { - err = fmt.Errorf("creating dsn: %v", err) + if c.QueryTag != "" { + sc.Params["query_tag"] = &c.QueryTag + } + + dsn, err = gosnowflake.DSN(&sc) + if err != nil { + err = fmt.Errorf("creating dsn: %v", err) + } } return } From 8443f765e7946959da95c7a15618da70a4ace26b Mon Sep 17 00:00:00 2001 From: Arnab Pal Date: Tue, 17 Dec 2024 22:18:45 +0530 Subject: [PATCH 3/7] Logging added to Snowflake connector --- sqlconnect/internal/snowflake/config.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sqlconnect/internal/snowflake/config.go b/sqlconnect/internal/snowflake/config.go index 817ca8d..abf1b6a 100644 --- a/sqlconnect/internal/snowflake/config.go +++ b/sqlconnect/internal/snowflake/config.go @@ -42,6 +42,14 @@ type Config struct { func (c Config) ConnectionString() (dsn string, err error) { if c.UseOAuth { + fmt.Println("sqlconnect: Account: " + c.Account) + fmt.Println("sqlconnect: Region: " + c.Region) + fmt.Println("sqlconnect: Token: " + c.OAuthToken) + fmt.Println("sqlconnect: Warehouse: " + c.Warehouse) + fmt.Println("sqlconnect: Schema: " + c.Schema) + fmt.Println("sqlconnect: Host: " + c.Host) + fmt.Println("sqlconnect: DBName: " + c.DBName) + sc := gosnowflake.Config{ Authenticator: gosnowflake.AuthTypeOAuth, Account: c.Account, From e1bd4b0f97f4b6d835bd80cec93c20f6a0111f5e Mon Sep 17 00:00:00 2001 From: Arnab Pal Date: Thu, 9 Jan 2025 18:06:07 +0530 Subject: [PATCH 4/7] Addressed review comments --- sqlconnect/internal/snowflake/config.go | 113 ++++++++++-------------- 1 file changed, 47 insertions(+), 66 deletions(-) diff --git a/sqlconnect/internal/snowflake/config.go b/sqlconnect/internal/snowflake/config.go index abf1b6a..389a075 100644 --- a/sqlconnect/internal/snowflake/config.go +++ b/sqlconnect/internal/snowflake/config.go @@ -22,12 +22,19 @@ type Config struct { Role string `json:"role"` Region string `json:"region"` + Protocol string `json:"protocol"` // http or https (optional) + Host string `json:"host"` // hostname (optional) + Port int `json:"port"` // port (optional) + Password string `json:"password"` UseKeyPairAuth bool `json:"useKeyPairAuth"` PrivateKey string `json:"privateKey"` PrivateKeyPassphrase string `json:"privateKeyPassphrase"` + UseOAuth bool `json:"useOAuth"` + OAuthToken string `json:"oauthToken"` + Application string `json:"application"` LoginTimeout time.Duration `json:"loginTimeout"` // default: 5m @@ -35,80 +42,54 @@ type Config struct { KeepSessionAlive bool `json:"keepSessionAlive"` UseLegacyMappings bool `json:"useLegacyMappings"` QueryTag string `json:"queryTag"` - Host string `json:"host"` - UseOAuth bool `json:"use_oauth"` - OAuthToken string `json:"oauth_token"` } func (c Config) ConnectionString() (dsn string, err error) { - if c.UseOAuth { - fmt.Println("sqlconnect: Account: " + c.Account) - fmt.Println("sqlconnect: Region: " + c.Region) - fmt.Println("sqlconnect: Token: " + c.OAuthToken) - fmt.Println("sqlconnect: Warehouse: " + c.Warehouse) - fmt.Println("sqlconnect: Schema: " + c.Schema) - fmt.Println("sqlconnect: Host: " + c.Host) - fmt.Println("sqlconnect: DBName: " + c.DBName) - - sc := gosnowflake.Config{ - Authenticator: gosnowflake.AuthTypeOAuth, - Account: c.Account, - Region: c.Region, - Token: c.OAuthToken, - Warehouse: c.Warehouse, - Schema: c.Schema, - Database: c.DBName, - Host: c.Host, - Protocol: "https", - Port: 443, - KeepSessionAlive: true, - } - dsn, err = gosnowflake.DSN(&sc) - if err != nil { - err = fmt.Errorf("creating dsn: %v", err) - } - } else { - sc := gosnowflake.Config{ - Authenticator: gosnowflake.AuthTypeSnowflake, - User: c.User, - Password: c.Password, - Account: c.Account, - Database: c.DBName, - Warehouse: c.Warehouse, - Schema: c.Schema, - Role: c.Role, - Application: c.Application, - LoginTimeout: c.LoginTimeout, - Params: make(map[string]*string), - } + sc := gosnowflake.Config{ + Authenticator: gosnowflake.AuthTypeSnowflake, + User: c.User, + Password: c.Password, + Account: c.Account, + Database: c.DBName, + Warehouse: c.Warehouse, + Schema: c.Schema, + Role: c.Role, + Region: c.Region, + Protocol: c.Protocol, + Host: c.Host, + Port: c.Port, + Application: c.Application, + LoginTimeout: c.LoginTimeout, + Params: make(map[string]*string), + } - if c.UseKeyPairAuth { - sc.Authenticator = gosnowflake.AuthTypeJwt - privateKey, err := c.ParsePrivateKey() - if err != nil { - return "", fmt.Errorf("parsing private key: %w", err) - } - sc.PrivateKey = privateKey - } else if c.UseOAuth { - sc.Authenticator = gosnowflake.AuthTypeOAuth - sc.Host = c.Host - sc.Token = c.OAuthToken - sc.User = c.User + if c.UseKeyPairAuth { + sc.Authenticator = gosnowflake.AuthTypeJwt + privateKey, err := c.ParsePrivateKey() + if err != nil { + return "", fmt.Errorf("parsing private key: %w", err) } + sc.PrivateKey = privateKey + } else if c.UseOAuth { + sc.Authenticator = gosnowflake.AuthTypeOAuth + sc.Token = c.OAuthToken + sc.Port = 443 + sc.Protocol = "https" + } - if c.KeepSessionAlive { - valueTrue := "true" - sc.Params["client_session_keep_alive"] = &valueTrue - } + if c.KeepSessionAlive { + // valueTrue := "true" + // sc.Params["client_session_keep_alive"] = &valueTrue + sc.KeepSessionAlive = true + } - if c.QueryTag != "" { - sc.Params["query_tag"] = &c.QueryTag - } + if c.QueryTag != "" { + sc.Params["query_tag"] = &c.QueryTag + } - dsn, err = gosnowflake.DSN(&sc) - if err != nil { - err = fmt.Errorf("creating dsn: %v", err) - } + dsn, err = gosnowflake.DSN(&sc) + if err != nil { + err = fmt.Errorf("creating dsn: %v", err) } return } From 8046a840302aec0418495585a0f01ca287452e68 Mon Sep 17 00:00:00 2001 From: Arnab Pal Date: Thu, 9 Jan 2025 22:31:06 +0530 Subject: [PATCH 5/7] Removed hardcoded port and protocol values --- sqlconnect/internal/snowflake/config.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sqlconnect/internal/snowflake/config.go b/sqlconnect/internal/snowflake/config.go index 389a075..5dd0ab9 100644 --- a/sqlconnect/internal/snowflake/config.go +++ b/sqlconnect/internal/snowflake/config.go @@ -73,14 +73,11 @@ func (c Config) ConnectionString() (dsn string, err error) { } else if c.UseOAuth { sc.Authenticator = gosnowflake.AuthTypeOAuth sc.Token = c.OAuthToken - sc.Port = 443 - sc.Protocol = "https" } if c.KeepSessionAlive { - // valueTrue := "true" - // sc.Params["client_session_keep_alive"] = &valueTrue - sc.KeepSessionAlive = true + valueTrue := "true" + sc.Params["client_session_keep_alive"] = &valueTrue } if c.QueryTag != "" { From aa9f9fd507f43e9341762b200d777db653c54ea0 Mon Sep 17 00:00:00 2001 From: Aris Tzoumas Date: Mon, 13 Jan 2025 11:12:14 +0200 Subject: [PATCH 6/7] chore: add a test for oauth --- .../internal/snowflake/authentication_test.go | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/sqlconnect/internal/snowflake/authentication_test.go b/sqlconnect/internal/snowflake/authentication_test.go index 2dd71ad..d2dafa5 100644 --- a/sqlconnect/internal/snowflake/authentication_test.go +++ b/sqlconnect/internal/snowflake/authentication_test.go @@ -1,7 +1,13 @@ package snowflake_test import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" "os" + "strings" "testing" "github.com/stretchr/testify/require" @@ -38,4 +44,60 @@ func TestSnowflakeAuthentication(t *testing.T) { defer func() { _ = db.Close() }() require.NoError(t, db.Ping(), "it should be able to ping the database") }) + t.Run("oauth", func(t *testing.T) { + authCode, ok := os.LookupEnv("SNOWFLAKE_TEST_AUTH_OAUTH_CODE") + if !ok { + t.Skip("skipping test due to lack of a test environment") + } + + configJSON, ok := os.LookupEnv("SNOWFLAKE_TEST_ENVIRONMENT_CREDENTIALS") + require.True(t, ok, "it should be able to get the environment credentials") + var conf snowflake.Config + require.NoError(t, json.Unmarshal([]byte(configJSON), &conf), "it should be able to unmarshal the config") + // reset username and password + conf.User = "" + conf.Password = "" + + // Issue a token + var accessToken string + { + var oauthCreds struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + } + oauthCredsJSON, ok := os.LookupEnv("SNOWFLAKE_TEST_AUTH_OAUTH_CREDENTIALS") + require.True(t, ok, "it should be able to get the oauth creds") + require.NoError(t, json.Unmarshal([]byte(oauthCredsJSON), &oauthCreds), "it should be able to unmarshal the oauth creds") + body := url.Values{} + body.Add("redirect_uri", "https://localhost.com") + body.Add("code", authCode) + body.Add("grant_type", "authorization_code") + body.Add("scope", fmt.Sprintf("session:role:%s", conf.Role)) + r, _ := http.NewRequest(http.MethodPost, fmt.Sprintf("https://%s.snowflakecomputing.com/oauth/token-request", conf.Account), strings.NewReader(body.Encode())) + r.Header.Add("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8") + r.SetBasicAuth(oauthCreds.ClientID, oauthCreds.ClientSecret) + resp, err := http.DefaultClient.Do(r) + require.NoError(t, err, "it should be able to issue a token") + defer func() { _ = resp.Body.Close() }() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err, "it should be able to read the response body") + require.Equalf(t, http.StatusOK, resp.StatusCode, "it should be able to issue a token: %s", string(respBody)) + var token struct { + AccessToken string `json:"access_token"` + } + require.NoError(t, json.Unmarshal(respBody, &token), "it should be able to decode the token") + accessToken = token.AccessToken + } + + conf.UseOAuth = true + conf.OAuthToken = accessToken + oauthConfigJSON, err := json.Marshal(conf) + require.NoError(t, err, "it should be able to marshal the config") + db, err := sqlconnect.NewDB(snowflake.DatabaseType, oauthConfigJSON) + require.NoError(t, err, "it should be able to create a new DB") + defer func() { _ = db.Close() }() + require.NoError(t, db.Ping(), "it should be able to ping the database") + require.NoError(t, db.QueryRow("SELECT 1").Err()) + + }) } From 9f905528fbdd8f2e9b27f6c733a7d5b9211e80a6 Mon Sep 17 00:00:00 2001 From: Aris Tzoumas Date: Wed, 15 Jan 2025 12:48:46 +0200 Subject: [PATCH 7/7] chore: make lint --- sqlconnect/internal/snowflake/authentication_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/sqlconnect/internal/snowflake/authentication_test.go b/sqlconnect/internal/snowflake/authentication_test.go index d2dafa5..8f9a4cc 100644 --- a/sqlconnect/internal/snowflake/authentication_test.go +++ b/sqlconnect/internal/snowflake/authentication_test.go @@ -98,6 +98,5 @@ func TestSnowflakeAuthentication(t *testing.T) { defer func() { _ = db.Close() }() require.NoError(t, db.Ping(), "it should be able to ping the database") require.NoError(t, db.QueryRow("SELECT 1").Err()) - }) }