diff --git a/pkg/log/log.go b/pkg/log/log.go index 22459bd228..a03ee790b1 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -22,30 +22,28 @@ const ( type logger struct { level Level - entry *logrus.Entry ctx context.Context err error } -func Info() Logger { +// common logger implementation used in the library +var log = logrus.New() +func Info() Logger { return &logger{ level: InfoLevel, - entry: logrus.NewEntry(logrus.New()), } } func Error() Logger { return &logger{ level: ErrorLevel, - entry: logrus.NewEntry(logrus.New()), } } func Debug() Logger { return &logger{ level: DebugLevel, - entry: logrus.NewEntry(logrus.New()), } } @@ -54,12 +52,12 @@ func Print(msg string) { Info().Print(msg) } -func WithContext(ctx context.Context) { - Info().WithContext(ctx) +func WithContext(ctx context.Context) Logger { + return Info().WithContext(ctx) } -func WithError(err error) { - Info().WithError(err) +func WithError(err error) Logger { + return Info().WithError(err) } func (l *logger) Print(msg string) { @@ -69,27 +67,19 @@ func (l *logger) Print(msg string) { logFields[cf.Key()] = cf.Value() } } - - l.entry = l.entry.WithFields(logFields) - - switch l.level { - case InfoLevel: - l.entry.Info(msg) - case ErrorLevel: - l.entry.Error(msg) - case DebugLevel: - l.entry.Debug(msg) + entry := log.WithFields(logFields) + if l.err != nil { + entry = entry.WithError(l.err) } + entry.Log(logrus.Level(l.level), msg) } func (l *logger) WithContext(ctx context.Context) Logger { l.ctx = ctx - l.entry = l.entry.WithContext(ctx) return l } func (l *logger) WithError(err error) Logger { l.err = err - l.entry = l.entry.WithError(err) return l } diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go new file mode 100644 index 0000000000..0b2660fd96 --- /dev/null +++ b/pkg/log/log_test.go @@ -0,0 +1,98 @@ +package log + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/kanisterio/kanister/pkg/field" + "github.com/sirupsen/logrus" + . "gopkg.in/check.v1" +) + +const ( + infoLevelStr = "info" + errorLevelStr = "error" + debugLevelStr = "debug" +) + +type LogSuite struct{} + +var _ = Suite(&LogSuite{}) + +func Test(t *testing.T) { + TestingT(t) +} + +func (s *LogSuite) TestWithNilError(c *C) { + log.SetFormatter(&logrus.JSONFormatter{TimestampFormat: time.RFC3339Nano}) + // Should not panic + WithError(nil).Print("Message") +} + +func (s *LogSuite) TestWithNilContext(c *C) { + log.SetFormatter(&logrus.JSONFormatter{TimestampFormat: time.RFC3339Nano}) + // Should not panic + WithContext(nil).Print("Message") +} + +func (s *LogSuite) TestLogMessage(c *C) { + const text = "Some useful text." + testLogMessage(c, text, Print) +} + +func (s *LogSuite) TestLogWithError(c *C) { + const text = "My error message" + err := errors.New("test error") + entry := testLogMessage(c, text, WithError(err).Print) + c.Assert(entry["error"], Equals, err.Error()) + c.Assert(entry["level"], Equals, infoLevelStr) +} + +func (s *LogSuite) TestLogWithContext(c *C) { + const text = "My error message" + ctx := context.Background() + entry := testLogMessage(c, text, WithContext(ctx).Print) + c.Assert(entry["level"], Equals, infoLevelStr) + // Error should not be set in the log entry + c.Assert(entry["error"], Equals, nil) +} + +func (s *LogSuite) TestLogWithContextFields(c *C) { + const text = "My error message" + ctx := field.Context(context.Background(), "key", "value") + entry := testLogMessage(c, text, WithContext(ctx).Print) + c.Assert(entry["level"], Equals, infoLevelStr) + // Error should not be set in the log entry + c.Assert(entry["error"], Equals, nil) + // A field with "key" should be set in the log entry + c.Assert(entry["key"], Equals, "value") +} + +func (s *LogSuite) TestLogWithContextFieldsAndError(c *C) { + const text = "My error message" + ctx := field.Context(context.Background(), "key", "value") + err := errors.New("test error") + entry := testLogMessage(c, text, WithError(err).WithContext(ctx).Print) + c.Assert(entry["level"], Equals, infoLevelStr) + // Error should be included in the log entry + c.Assert(entry["error"], Equals, err.Error()) + // A field with "key" should be set in the log entry + c.Assert(entry["key"], Equals, "value") +} + +func testLogMessage(c *C, msg string, print func(string)) map[string]interface{} { + log.SetFormatter(&logrus.JSONFormatter{TimestampFormat: time.RFC3339Nano}) + var memLog bytes.Buffer + log.SetOutput(&memLog) + print(msg) + var entry map[string]interface{} + err := json.Unmarshal(memLog.Bytes(), &entry) + c.Assert(err, IsNil) + c.Assert(entry, NotNil) + c.Assert(entry["msg"], Equals, msg) + return entry +}