Skip to content

Commit

Permalink
Add timeout for the external browser authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pbulawa committed Jun 28, 2023
1 parent ee57d93 commit 5cf7d47
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 21 deletions.
3 changes: 2 additions & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,8 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.cfg.Application,
sc.cfg.Account,
sc.cfg.User,
sc.cfg.Password)
sc.cfg.Password,
sc.cfg.ExternalBrowserTimeout)
if err != nil {
sc.cleanup()
return err
Expand Down
46 changes: 38 additions & 8 deletions authexternalbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
Expand All @@ -14,6 +15,7 @@ import (
"net/url"
"strconv"
"strings"
"time"

"github.com/pkg/browser"
)
Expand Down Expand Up @@ -165,6 +167,34 @@ func getTokenFromResponse(response string) (string, error) {
return token, nil
}

type authenticateByExternalBrowserResult struct {
escapedSamlResponse []byte
proofKey []byte
err error
}

func authenticateByExternalBrowser(
ctx context.Context,
sr *snowflakeRestful,
authenticator string,
application string,
account string,
user string,
password string,
externalBrowserTimeout time.Duration,
) ([]byte, []byte, error) {
resultChan := make(chan authenticateByExternalBrowserResult, 1)
go func() {
resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password)
}()
select {
case <-time.After(externalBrowserTimeout):
return nil, nil, errors.New("authentication timed out")
case result := <-resultChan:
return result.escapedSamlResponse, result.proofKey, result.err
}
}

// Authentication by an external browser takes place via the following:
// - the golang snowflake driver communicates to Snowflake that the user wishes to
// authenticate via external browser
Expand All @@ -174,30 +204,30 @@ func getTokenFromResponse(response string) (string, error) {
// - user authenticates at the IDP, and is redirected to Snowflake
// - Snowflake directs the user back to the driver
// - authenticate is complete!
func authenticateByExternalBrowser(
func doAuthenticateByExternalBrowser(
ctx context.Context,
sr *snowflakeRestful,
authenticator string,
application string,
account string,
user string,
password string,
) ([]byte, []byte, error) {
) authenticateByExternalBrowserResult {
l, err := bindToPort()
if err != nil {
return nil, nil, err
return authenticateByExternalBrowserResult{nil, nil, err}
}
defer l.Close()

callbackPort := l.Addr().(*net.TCPAddr).Port
idpURL, proofKey, err := getIdpURLProofKey(
ctx, sr, authenticator, application, account, callbackPort)
if err != nil {
return nil, nil, err
return authenticateByExternalBrowserResult{nil, nil, err}
}

if err = openBrowser(idpURL); err != nil {
return nil, nil, err
return authenticateByExternalBrowserResult{nil, nil, err}
}

encodedSamlResponseChan := make(chan string)
Expand Down Expand Up @@ -253,13 +283,13 @@ func authenticateByExternalBrowser(
errFromGoroutine = <-errChan

if errFromGoroutine != nil {
return nil, nil, errFromGoroutine
return authenticateByExternalBrowserResult{nil, nil, errFromGoroutine}
}

escapedSamlResponse, err := url.QueryUnescape(encodedSamlResponse)
if err != nil {
logger.WithContext(ctx).Errorf("unable to unescape saml response. err: %v", err)
return nil, nil, err
return authenticateByExternalBrowserResult{nil, nil, err}
}
return []byte(escapedSamlResponse), []byte(proofKey), nil
return authenticateByExternalBrowserResult{[]byte(escapedSamlResponse), []byte(proofKey), nil}
}
27 changes: 24 additions & 3 deletions authexternalbrowser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,24 +83,25 @@ func TestUnitAuthenticateByExternalBrowser(t *testing.T) {
account := "testaccount"
user := "u"
password := "p"
timeout := defaultExternalBrowserTimeout
sr := &snowflakeRestful{
Protocol: "https",
Host: "abc.com",
Port: 443,
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password)
_, _, err := authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFail
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password)
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password)
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout)
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -112,3 +113,23 @@ func TestUnitAuthenticateByExternalBrowser(t *testing.T) {
t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeFailedToConnect, driverErr.Number)
}
}

func TestAuthenticationTimeout(t *testing.T) {
authenticator := "externalbrowser"
application := "testapp"
account := "testaccount"
user := "u"
password := "p"
timeout := 0 * time.Second
sr := &snowflakeRestful{
Protocol: "https",
Host: "abc.com",
Port: 443,
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout)
if err.Error() != "authentication timed out" {
t.Fatal("should have timed out")
}
}
23 changes: 14 additions & 9 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ import (
)

const (
defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response
defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout
defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout
defaultJWTTimeout = 60 * time.Second
defaultDomain = ".snowflakecomputing.com"
defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response
defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout
defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout
defaultJWTTimeout = 60 * time.Second
defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login
defaultDomain = ".snowflakecomputing.com"
)

// ConfigBool is a type to represent true or false in the Config
Expand Down Expand Up @@ -66,10 +67,11 @@ type Config struct {

OktaURL *url.URL

LoginTimeout time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response
RequestTimeout time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response
JWTExpireTimeout time.Duration // JWT expire after timeout
ClientTimeout time.Duration // Timeout for network round trip + read out http response
LoginTimeout time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response
RequestTimeout time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response
JWTExpireTimeout time.Duration // JWT expire after timeout
ClientTimeout time.Duration // Timeout for network round trip + read out http response
ExternalBrowserTimeout time.Duration // Timeout for external browser login

Application string // application name.
InsecureMode bool // driver doesn't check certificate revocation status
Expand Down Expand Up @@ -429,6 +431,9 @@ func fillMissingConfigParameters(cfg *Config) error {
if cfg.ClientTimeout == 0 {
cfg.ClientTimeout = defaultClientTimeout
}
if cfg.ExternalBrowserTimeout == 0 {
cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout
}
if strings.Trim(cfg.Application, " ") == "" {
cfg.Application = clientType
}
Expand Down

0 comments on commit 5cf7d47

Please sign in to comment.