Skip to content

Commit

Permalink
Merge branch 'master' into SNOW-1524204-handle-session-variables-in-dsn
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dszmolka authored Jul 15, 2024
2 parents 37e107b + f2164e5 commit 02dfde5
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 57 deletions.
14 changes: 2 additions & 12 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ import (
"io"
"net/http"
"net/url"
"os"
"regexp"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -67,8 +65,6 @@ const (
executionTypeStatement string = "statement"
)

const privateLinkSuffix = "privatelink.snowflakecomputing.com"

type snowflakeConn struct {
ctx context.Context
cfg *Config
Expand Down Expand Up @@ -777,14 +773,8 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err
// use the custom transport
st = sc.cfg.Transporter
}
if strings.HasSuffix(sc.cfg.Host, privateLinkSuffix) {
if err := sc.setupOCSPPrivatelink(sc.cfg.Application, sc.cfg.Host); err != nil {
return nil, err
}
} else {
if _, set := os.LookupEnv(cacheServerURLEnv); set {
os.Unsetenv(cacheServerURLEnv)
}
if err = setupOCSPEnvVars(sc.ctx, sc.cfg.Host); err != nil {
return nil, err
}
var tokenAccessor TokenAccessor
if sc.cfg.TokenAccessor != nil {
Expand Down
93 changes: 85 additions & 8 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,29 @@ func fetchResultByQueryID(
return nil
}

func TestPrivateLink(t *testing.T) {
func TestIsPrivateLink(t *testing.T) {
for _, tc := range []struct {
host string
isPrivatelink bool
}{
{"testaccount.us-east-1.snowflakecomputing.com", false},
{"testaccount-no-privatelink.snowflakecomputing.com", false},
{"testaccount.us-east-1.privatelink.snowflakecomputing.com", true},
{"testaccount.cn-region.snowflakecomputing.cn", false},
{"testaccount.cn-region.privaTELINk.snowflakecomputing.cn", true},
{"testaccount.some-region.privatelink.snowflakecomputing.mil", true},
{"testaccount.us-east-1.privatelink.snowflakecOMPUTING.com", true},
{"snowhouse.snowflakecomputing.xyz", false},
{"snowhouse.privatelink.snowflakecomputing.xyz", true},
{"snowhouse.PRIVATELINK.snowflakecomputing.xyz", true},
} {
t.Run(tc.host, func(t *testing.T) {
assertEqualE(t, isPrivateLink(tc.host), tc.isPrivatelink)
})
}
}

func TestBuildPrivatelinkConn(t *testing.T) {
if _, err := buildSnowflakeConn(context.Background(), Config{
Account: "testaccount",
User: "testuser",
Expand All @@ -486,15 +508,70 @@ func TestPrivateLink(t *testing.T) {
}); err != nil {
t.Error(err)
}
defer func() {
os.Unsetenv(cacheServerURLEnv)
os.Unsetenv(ocspRetryURLEnv)
}()

ocspURL := os.Getenv(cacheServerURLEnv)
expectedURL := "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json"
if ocspURL != expectedURL {
t.Errorf("expected: %v, got: %v", expectedURL, ocspURL)
}
assertEqualE(t, ocspURL, "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json")
retryURL := os.Getenv(ocspRetryURLEnv)
expectedURL = "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/retry/%v/%v"
if retryURL != expectedURL {
t.Errorf("expected: %v, got: %v", expectedURL, retryURL)
assertEqualE(t, retryURL, "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/retry/%v/%v")
}

func TestOcspEnvVarsSetup(t *testing.T) {
ctx := context.Background()
for _, tc := range []struct {
host string
cacheURL string
privateLinkRetryURL string
}{
{
host: "testaccount.us-east-1.snowflakecomputing.com",
cacheURL: "", // no privatelink, default ocsp cache URL, no need to setup env vars
privateLinkRetryURL: "",
},
{
host: "testaccount-no-privatelink.snowflakecomputing.com",
cacheURL: "", // no privatelink, default ocsp cache URL, no need to setup env vars
privateLinkRetryURL: "",
},
{
host: "testaccount.us-east-1.privatelink.snowflakecomputing.com",
cacheURL: "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json",
privateLinkRetryURL: "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/retry/%v/%v",
},
{
host: "testaccount.cn-region.snowflakecomputing.cn",
cacheURL: "http://ocsp.testaccount.cn-region.snowflakecomputing.cn/ocsp_response_cache.json",
privateLinkRetryURL: "", // not a privatelink env, no need to setup retry URL
},
{
host: "testaccount.cn-region.privaTELINk.snowflakecomputing.cn",
cacheURL: "http://ocsp.testaccount.cn-region.privatelink.snowflakecomputing.cn/ocsp_response_cache.json",
privateLinkRetryURL: "http://ocsp.testaccount.cn-region.privatelink.snowflakecomputing.cn/retry/%v/%v",
},
{
host: "testaccount.some-region.privatelink.snowflakecomputing.mil",
cacheURL: "http://ocsp.testaccount.some-region.privatelink.snowflakecomputing.mil/ocsp_response_cache.json",
privateLinkRetryURL: "http://ocsp.testaccount.some-region.privatelink.snowflakecomputing.mil/retry/%v/%v",
},
} {
t.Run(tc.host, func(t *testing.T) {
if err := setupOCSPEnvVars(ctx, tc.host); err != nil {
t.Errorf("error during OCSP env vars setup; %v", err)
}
defer func() {
os.Unsetenv(cacheServerURLEnv)
os.Unsetenv(ocspRetryURLEnv)
}()

cacheURLFromEnv := os.Getenv(cacheServerURLEnv)
assertEqualE(t, cacheURLFromEnv, tc.cacheURL)
retryURL := os.Getenv(ocspRetryURLEnv)
assertEqualE(t, retryURL, tc.privateLinkRetryURL)

})
}
}

Expand Down
37 changes: 33 additions & 4 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,20 +288,49 @@ func populateChunkDownloader(
}
}

func (sc *snowflakeConn) setupOCSPPrivatelink(app string, host string) error {
ocspCacheServer := fmt.Sprintf("http://ocsp.%v/ocsp_response_cache.json", host)
logger.WithContext(sc.ctx).Debugf("OCSP Cache Server for Privatelink: %v\n", ocspCacheServer)
func setupOCSPEnvVars(ctx context.Context, host string) error {
host = strings.ToLower(host)
if isPrivateLink(host) {
if err := setupOCSPPrivatelink(ctx, host); err != nil {
return err
}
} else if !strings.HasSuffix(host, defaultDomain) {
ocspCacheServer := fmt.Sprintf("http://ocsp.%v/%v", host, cacheFileBaseName)
logger.WithContext(ctx).Debugf("OCSP Cache Server for %v: %v\n", host, ocspCacheServer)
if err := os.Setenv(cacheServerURLEnv, ocspCacheServer); err != nil {
return err
}
} else {
if _, set := os.LookupEnv(cacheServerURLEnv); set {
os.Unsetenv(cacheServerURLEnv)
}
}
return nil
}

func setupOCSPPrivatelink(ctx context.Context, host string) error {
ocspCacheServer := fmt.Sprintf("http://ocsp.%v/%v", host, cacheFileBaseName)
logger.WithContext(ctx).Debugf("OCSP Cache Server for Privatelink: %v\n", ocspCacheServer)
if err := os.Setenv(cacheServerURLEnv, ocspCacheServer); err != nil {
return err
}
ocspRetryHostTemplate := fmt.Sprintf("http://ocsp.%v/retry/", host) + "%v/%v"
logger.WithContext(sc.ctx).Debugf("OCSP Retry URL for Privatelink: %v\n", ocspRetryHostTemplate)
logger.WithContext(ctx).Debugf("OCSP Retry URL for Privatelink: %v\n", ocspRetryHostTemplate)
if err := os.Setenv(ocspRetryURLEnv, ocspRetryHostTemplate); err != nil {
return err
}
return nil
}

/**
* We can only tell if private link is enabled for certain hosts when the hostname contains the subdomain
* 'privatelink.snowflakecomputing.' but we don't have a good way of telling if a private link connection is
* expected for internal stages for example.
*/
func isPrivateLink(host string) bool {
return strings.Contains(strings.ToLower(host), ".privatelink.snowflakecomputing.")
}

func isStatementContext(ctx context.Context) bool {
v := ctx.Value(executionType)
return v == executionTypeStatement
Expand Down
7 changes: 7 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"database/sql/driver"
"os"
"runtime"
"strings"
"sync"
)

Expand Down Expand Up @@ -41,6 +42,12 @@ func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (dri
return nil, err
}

if strings.HasSuffix(strings.ToLower(config.Host), cnDomain) {
logger.WithContext(ctx).Info("Connecting to CHINA Snowflake domain")
} else {
logger.WithContext(ctx).Info("Connecting to GLOBAL Snowflake domain")
}

if err = authenticateWithConfig(sc); err != nil {
return nil, err
}
Expand Down
96 changes: 69 additions & 27 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ const (
defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login
defaultMaxRetryCount = 7 // specifies maximum number of subsequent retries
defaultDomain = ".snowflakecomputing.com"
cnDomain = ".snowflakecomputing.cn"
topLevelDomainPrefix = ".snowflakecomputing." // used to extract the domain from host
)

// ConfigBool is a type to represent true or false in the Config
Expand Down Expand Up @@ -135,26 +137,26 @@ func (c *Config) ocspMode() string {

// DSN constructs a DSN for Snowflake db.
func DSN(cfg *Config) (dsn string, err error) {
if cfg.Region == "us-west-2" {
cfg.Region = ""
}
// in case account includes region
region, posDot := extractRegionFromAccount(cfg.Account)
if region != "" {
if cfg.Region != "" {
return "", errRegionConflict()
}
cfg.Region = region
cfg.Account = cfg.Account[:posDot]
}
hasHost := true
if cfg.Host == "" {
hasHost = false
if cfg.Region == "us-west-2" {
cfg.Region = ""
}
if cfg.Region == "" {
cfg.Host = cfg.Account + defaultDomain
} else {
cfg.Host = cfg.Account + "." + cfg.Region + defaultDomain
}
}
// in case account includes region
posDot := strings.Index(cfg.Account, ".")
if posDot > 0 {
if cfg.Region != "" {
return "", errInvalidRegion()
cfg.Host = buildHostFromAccountAndRegion(cfg.Account, cfg.Region)
}
cfg.Region = cfg.Account[posDot+1:]
cfg.Account = cfg.Account[:posDot]
}
err = fillMissingConfigParameters(cfg)
if err != nil {
Expand Down Expand Up @@ -374,7 +376,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
return
}
}
if cfg.Account == "" && strings.HasSuffix(cfg.Host, defaultDomain) {
if cfg.Account == "" && hostIncludesTopLevelDomain(cfg.Host) {
posDot := strings.Index(cfg.Host, ".")
if posDot > 0 {
cfg.Account = cfg.Host[:posDot]
Expand Down Expand Up @@ -428,7 +430,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
func fillMissingConfigParameters(cfg *Config) error {
posDash := strings.LastIndex(cfg.Account, "-")
if posDash > 0 {
if strings.Contains(cfg.Host, ".global.") {
if strings.Contains(strings.ToLower(cfg.Host), ".global.") {
cfg.Account = cfg.Account[:posDash]
}
}
Expand All @@ -453,19 +455,24 @@ func fillMissingConfigParameters(cfg *Config) error {
cfg.Region = strings.Trim(cfg.Region, " ")
if cfg.Region != "" {
// region is specified but not included in Host
i := strings.Index(cfg.Host, defaultDomain)
domain, i := extractDomainFromHost(cfg.Host)
if i >= 1 {
hostPrefix := cfg.Host[0:i]
if !strings.HasSuffix(hostPrefix, cfg.Region) {
cfg.Host = hostPrefix + "." + cfg.Region + defaultDomain
cfg.Host = fmt.Sprintf("%v.%v%v", hostPrefix, cfg.Region, domain)
}
}
}
if cfg.Host == "" {
if cfg.Region != "" {
cfg.Host = cfg.Account + "." + cfg.Region + defaultDomain
cfg.Host = cfg.Account + "." + cfg.Region + getDomainBasedOnRegion(cfg.Region)
} else {
cfg.Host = cfg.Account + defaultDomain
region, _ := extractRegionFromAccount(cfg.Account)
if region != "" {
cfg.Host = cfg.Account + getDomainBasedOnRegion(region)
} else {
cfg.Host = cfg.Account + defaultDomain
}
}
}
if cfg.LoginTimeout == 0 {
Expand Down Expand Up @@ -505,7 +512,8 @@ func fillMissingConfigParameters(cfg *Config) error {
cfg.IncludeRetryReason = ConfigBoolTrue
}

if strings.HasSuffix(cfg.Host, defaultDomain) && len(cfg.Host) == len(defaultDomain) {
domain, _ := extractDomainFromHost(cfg.Host)
if len(cfg.Host) == len(domain) {
return &SnowflakeError{
Number: ErrCodeFailedToParseHost,
Message: errMsgFailedToParseHost,
Expand All @@ -515,6 +523,38 @@ func fillMissingConfigParameters(cfg *Config) error {
return nil
}

func extractDomainFromHost(host string) (domain string, index int) {
i := strings.LastIndex(strings.ToLower(host), topLevelDomainPrefix)
if i >= 1 {
domain = host[i:]
return domain, i
}
return "", i
}

func getDomainBasedOnRegion(region string) string {
if strings.HasPrefix(strings.ToLower(region), "cn-") {
return cnDomain
}
return defaultDomain
}

func extractRegionFromAccount(account string) (region string, posDot int) {
posDot = strings.Index(strings.ToLower(account), ".")
if posDot > 0 {
return account[posDot+1:], posDot
}
return "", posDot
}

func hostIncludesTopLevelDomain(host string) bool {
return strings.Contains(strings.ToLower(host), topLevelDomainPrefix)
}

func buildHostFromAccountAndRegion(account, region string) string {
return account + "." + region + getDomainBasedOnRegion(region)
}

func authRequiresUser(cfg *Config) bool {
return cfg.Authenticator != AuthTypeOAuth &&
cfg.Authenticator != AuthTypeTokenAccessor &&
Expand All @@ -528,18 +568,20 @@ func authRequiresPassword(cfg *Config) bool {
cfg.Authenticator != AuthTypeJwt
}

// transformAccountToHost transforms host to account name
// transformAccountToHost transforms account to host
func transformAccountToHost(cfg *Config) (err error) {
if cfg.Port == 0 && !strings.HasSuffix(cfg.Host, defaultDomain) && cfg.Host != "" {
if cfg.Port == 0 && cfg.Host != "" && !hostIncludesTopLevelDomain(cfg.Host) {
// account name is specified instead of host:port
cfg.Account = cfg.Host
cfg.Host = cfg.Account + defaultDomain
cfg.Port = 443
posDot := strings.Index(cfg.Account, ".")
if posDot > 0 {
cfg.Region = cfg.Account[posDot+1:]
region, posDot := extractRegionFromAccount(cfg.Account)
if region != "" {
cfg.Region = region
cfg.Account = cfg.Account[:posDot]
cfg.Host = buildHostFromAccountAndRegion(cfg.Account, cfg.Region)
} else {
cfg.Host = cfg.Account + defaultDomain
}
cfg.Port = 443
}
return nil
}
Expand Down
Loading

0 comments on commit 02dfde5

Please sign in to comment.