From 3168c62a086333f731e891ea46ffbbdfa0313efb Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lb Date: Fri, 28 Jul 2023 08:23:54 -0700 Subject: [PATCH] cover critical areas test code coverage for logger, dsn and connection --- connection_test.go | 39 ++++++++ dsn_test.go | 75 ++++++++++++++-- log_test.go | 219 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 328 insertions(+), 5 deletions(-) create mode 100644 log_test.go diff --git a/connection_test.go b/connection_test.go index f32141032..08962b643 100644 --- a/connection_test.go +++ b/connection_test.go @@ -657,3 +657,42 @@ func TestQueryContextError(t *testing.T) { t.Fatalf("should be snowflake error. err: %v", err) } } + +func TestPrepareQuery(t *testing.T) { + ctx := context.Background() + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(ctx, *config) + if err != nil { + t.Error(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Error(err) + } + _, err = sc.Prepare("SELECT 1") + + if err != nil { + t.Fatalf("failed to prepare query. err: %v", err) + } +} + +func TestBeginCreatesTransaction(t *testing.T) { + ctx := context.Background() + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(ctx, *config) + if err != nil { + t.Error(err) + } + if err = authenticateWithConfig(sc); err != nil { + t.Error(err) + } + tx, _ := sc.Begin() + if tx == nil { + t.Fatal("should have created a transaction with connection") + } +} diff --git a/dsn_test.go b/dsn_test.go index 0debed708..de983e7f2 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -727,6 +727,14 @@ func TestParseDSN(t *testing.T) { t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v", i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout) } + if test.config.ClientStoreTemporaryCredential != cfg.ClientStoreTemporaryCredential { + t.Fatalf("%d: Failed to match ClientStoreTemporaryCredential. expected: %v, got: %v", + i, test.config.ClientStoreTemporaryCredential, cfg.ClientStoreTemporaryCredential) + } + if test.config.ClientRequestMfaToken != cfg.ClientRequestMfaToken { + t.Fatalf("%d: Failed to match ClientRequestMfaToken. expected: %v, got: %v", + i, test.config.ClientRequestMfaToken, cfg.ClientRequestMfaToken) + } case test.err != nil: driverErrE, okE := test.err.(*SnowflakeError) driverErrG, okG := err.(*SnowflakeError) @@ -900,12 +908,13 @@ func TestDSN(t *testing.T) { }, { cfg: &Config{ - User: "u", - Password: "p", - Account: "a", - Authenticator: AuthTypeExternalBrowser, + User: "u", + Password: "p", + Account: "a", + Authenticator: AuthTypeExternalBrowser, + ClientStoreTemporaryCredential: ConfigBoolTrue, }, - dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&ocspFailOpen=true&validateDefaultParameters=true", + dsn: "u:p@a.snowflakecomputing.com:443?authenticator=externalbrowser&clientStoreTemporaryCredential=true&ocspFailOpen=true&validateDefaultParameters=true", }, { cfg: &Config{ @@ -1023,6 +1032,62 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&jwtClientTimeout=60&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + ClientTimeout: 300 * time.Second, + JWTExpireTimeout: 30 * time.Second, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&jwtTimeout=30&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Protocol: "http", + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true&protocol=http®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Tracing: "debug", + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&tracing=debug&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeUsernamePasswordMFA, + ClientRequestMfaToken: ConfigBoolTrue, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=username_password_mfa&clientRequestMfaToken=true&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Warehouse: "wh", + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&validateDefaultParameters=true&warehouse=wh", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Token: "t", + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&token=t&validateDefaultParameters=true", + }, } for _, test := range testcases { dsn, err := DSN(test.cfg) diff --git a/log_test.go b/log_test.go new file mode 100644 index 000000000..b585c6411 --- /dev/null +++ b/log_test.go @@ -0,0 +1,219 @@ +// Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. + +package gosnowflake + +import ( + "bytes" + "errors" + "strings" + "testing" + "time" + + rlog "github.com/sirupsen/logrus" +) + +func createTestLogger() defaultLogger { + var rLogger = rlog.New() + var ret = defaultLogger{inner: rLogger} + return ret +} + +func TestIsLevelEnabled(t *testing.T) { + logger := createTestLogger() + logger.SetLevel(rlog.TraceLevel) + if !logger.IsLevelEnabled(rlog.TraceLevel) { + t.Fatalf("log level should be trace but is %v", logger.GetLevel()) + } +} + +func TestLogFunction(t *testing.T) { + logger := createTestLogger() + buf := &bytes.Buffer{} + var formatter = rlog.TextFormatter{CallerPrettyfier: SFCallerPrettyfier} + logger.SetFormatter(&formatter) + logger.SetReportCaller(true) + logger.SetOutput(buf) + logger.SetLevel(rlog.TraceLevel) + + logger.Log(rlog.TraceLevel, "hello world") + logger.Logf(rlog.TraceLevel, "log %v", "format") + logger.Logln(rlog.TraceLevel, "log line") + + var strbuf = buf.String() + if !strings.Contains(strbuf, "hello world") && + !strings.Contains(strbuf, "log format") && + !strings.Contains(strbuf, "log line") { + t.Fatalf("unexpected output in log %v", strbuf) + } +} + +func TestSetLogLevelError(t *testing.T) { + logger := CreateDefaultLogger() + err := logger.SetLogLevel("unknown") + if err == nil { + t.Fatal("should have thrown an error") + } +} + +func TestDefaultLogLevel(t *testing.T) { + logger := CreateDefaultLogger() + buf := &bytes.Buffer{} + logger.SetOutput(buf) + SetLogger(&logger) + + // default logger level is info + logger.Info("info") + logger.Infof("info%v", "f") + logger.Infoln("infoln") + + // debug and trace won't write to log since they are higher than info level + logger.Debug("debug") + logger.Debugf("debug%v", "f") + logger.Debugln("debugln") + + logger.Trace("trace") + logger.Tracef("trace%v", "f") + logger.Traceln("traceln") + + // print, warning and error should write to log since they are lower than info + logger.Print("print") + logger.Printf("print%v", "f") + logger.Println("println") + + logger.Warn("warn") + logger.Warnf("warn%v", "f") + logger.Warnln("warnln") + + logger.Warning("warning") + logger.Warningf("warning%v", "f") + logger.Warningln("warningln") + + logger.Error("error") + logger.Errorf("error%v", "f") + logger.Errorln("errorln") + + // verify output + var strbuf = buf.String() + + if strings.Contains(strbuf, "debug") && + strings.Contains(strbuf, "trace") && + !strings.Contains(strbuf, "info") && + !strings.Contains(strbuf, "print") && + !strings.Contains(strbuf, "warn") && + !strings.Contains(strbuf, "warning") && + !strings.Contains(strbuf, "error") { + t.Fatalf("unexpected output in log: %v", strbuf) + } +} + +func TestLogSetLevel(t *testing.T) { + logger := GetLogger() + buf := &bytes.Buffer{} + logger.SetOutput(buf) + logger.SetLogLevel("trace") + + logger.Trace("should print at trace level") + logger.Debug("should print at debug level") + + var strbuf = buf.String() + + if !strings.Contains(strbuf, "trace level") && + !strings.Contains(strbuf, "debug level") { + t.Fatalf("unexpected output in log: %v", strbuf) + } +} + +func TestLogWithError(t *testing.T) { + logger := CreateDefaultLogger() + buf := &bytes.Buffer{} + logger.SetOutput(buf) + + err := errors.New("error") + logger.WithError(err).Info("hello world") + + var strbuf = buf.String() + if !strings.Contains(strbuf, "error=error") { + t.Fatalf("unexpected output in log: %v", strbuf) + } +} + +func TestLogWithTime(t *testing.T) { + logger := createTestLogger() + buf := &bytes.Buffer{} + logger.SetOutput(buf) + + ti := time.Now() + logger.WithTime(ti).Info("hello") + time.Sleep(3 * time.Second) + + var strbuf = buf.String() + if !strings.Contains(strbuf, ti.Format(time.RFC3339)) { + t.Fatalf("unexpected string in output: %v", strbuf) + } +} + +func TestLogWithField(t *testing.T) { + logger := CreateDefaultLogger() + buf := &bytes.Buffer{} + logger.SetOutput(buf) + + logger.WithField("field", "test").Info("hello") + var strbuf = buf.String() + if !strings.Contains(strbuf, "field=test") { + t.Fatalf("unexpected string in output: %v", strbuf) + } +} + +func TestLogLevelFunctions(t *testing.T) { + logger := createTestLogger() + buf := &bytes.Buffer{} + logger.SetOutput(buf) + + logger.TraceFn(func() []interface{} { + return []interface{}{ + "trace function", + } + }) + + logger.DebugFn(func() []interface{} { + return []interface{}{ + "debug function", + } + }) + + logger.InfoFn(func() []interface{} { + return []interface{}{ + "info function", + } + }) + + logger.PrintFn(func() []interface{} { + return []interface{}{ + "print function", + } + }) + + logger.WarningFn(func() []interface{} { + return []interface{}{ + "warning function", + } + }) + + logger.ErrorFn(func() []interface{} { + return []interface{}{ + "error function", + } + }) + + // check that info, print, warning and error were outputted to the log. + var strbuf = buf.String() + + if strings.Contains(strbuf, "debug") && + strings.Contains(strbuf, "trace") && + !strings.Contains(strbuf, "info") && + !strings.Contains(strbuf, "print") && + !strings.Contains(strbuf, "warning") && + !strings.Contains(strbuf, "error") { + t.Fatalf("unexpected output in log: %v", strbuf) + } +}