diff --git a/log/slog.go b/log/slog.go index 866ae3c3..6e0ddc31 100644 --- a/log/slog.go +++ b/log/slog.go @@ -6,6 +6,8 @@ import ( "log/slog" "os" "time" + "runtime" + "context" "github.com/cockroachdb/errors" "github.com/cockroachdb/errors/withstack" @@ -18,6 +20,18 @@ type RelayLogger struct { var relayLogger *RelayLogger func InitLogger(logLevel, format, output string) error { + // output + switch output { + case "stdout": + return InitLoggerWithWriter(logLevel, format, os.Stdout) + case "stderr": + return InitLoggerWithWriter(logLevel, format, os.Stderr) + default: + return errors.New("invalid log output") + } +} + +func InitLoggerWithWriter(logLevel, format string, writer io.Writer) error { // level var slogLevel slog.Level if err := slogLevel.UnmarshalText([]byte(logLevel)); err != nil { @@ -28,17 +42,6 @@ func InitLogger(logLevel, format, output string) error { AddSource: true, } - // output - var writer io.Writer - switch output { - case "stdout": - writer = os.Stdout - case "stderr": - writer = os.Stderr - default: - return errors.New("invalid log output") - } - var slogLogger *slog.Logger // format switch format { @@ -63,17 +66,38 @@ func InitLogger(logLevel, format, output string) error { return nil } -func (rl *RelayLogger) Error(msg string, err error, otherArgs ...any) { - err = withstack.WithStackDepth(err, 1) +func (rl *RelayLogger) log(logLevel slog.Level, skipCallDepth int, msg string, args ...any) { + ctx := context.Background(); + if !rl.Logger.Enabled(ctx, logLevel) { + return + } + + var pcs [1]uintptr + runtime.Callers(2 + skipCallDepth, pcs[:]) // skip [Callers, this func, ...] + + record := slog.NewRecord(time.Now(), logLevel, msg, pcs[0]) + record.Add(args...) + + // note that official log function also ignores Handle() error + _ = rl.Logger.Handler().Handle(ctx, record) +} + +func (rl *RelayLogger) error(skipCallDepth int, msg string, err error, otherArgs ...any) { + err = withstack.WithStackDepth(err, 1 + skipCallDepth) var args []any args = append(args, "error", err) args = append(args, "stack", fmt.Sprintf("%+v", err)) args = append(args, otherArgs...) - rl.Logger.Error(msg, args...) + + rl.log(slog.LevelError, 1 + skipCallDepth, msg, args...) +} + +func (rl *RelayLogger) Error(msg string, err error, otherArgs ...any) { + rl.error(1, msg, err, otherArgs...) } func (rl *RelayLogger) Fatal(msg string, err error, otherArgs ...any) { - rl.Error(msg, err, otherArgs...) + rl.error(1, msg, err, otherArgs...) panic(msg) } @@ -190,5 +214,5 @@ func (rl *RelayLogger) WithModule( func (rl *RelayLogger) TimeTrack(start time.Time, name string, otherArgs ...any) { elapsed := time.Since(start) allArgs := append([]any{"name", name, "elapsed", elapsed.Nanoseconds()}, otherArgs...) - rl.Logger.Info("time track", allArgs...) + rl.log(slog.LevelInfo, 1, "time track", allArgs...) } diff --git a/log/slog_test.go b/log/slog_test.go new file mode 100644 index 00000000..1b457d4e --- /dev/null +++ b/log/slog_test.go @@ -0,0 +1,97 @@ +package log + +import ( + "testing" + + "fmt" + "log/slog" + "bytes" + "encoding/json" + "regexp" +) + +type setupType struct { + logger *RelayLogger + buffer bytes.Buffer +} + +func beforeEach(t *testing.T) *setupType { + var r setupType + + err := InitLoggerWithWriter("info", "json", &r.buffer) + if err != nil { + t.Fatal(err) + } + + r.logger = GetLogger() + + return &r +} + +type logType struct { + Time string + Level string + Source struct { + Function string + File string + Line int + } + Msg string + Stack string + Error string +} + +func parseResult(setup *setupType, t *testing.T) (string, logType) { + raw := setup.buffer.String() + var parsed logType + + err := json.Unmarshal(setup.buffer.Bytes(), &parsed) + if err != nil { + t.Fatalf("fail to parse log: %v: %s", err, raw) + } + + return raw, parsed +} + +func TestLogLevel(t *testing.T) { + setup := beforeEach(t) + + setup.logger.log(slog.LevelDebug, 0, "test") + if 0 < setup.buffer.Len() { + t.Fatalf("debug log is output: %s", setup.buffer.String()) + } +} + +func TestLogLog(t *testing.T) { + setup := beforeEach(t) + + setup.logger.log(slog.LevelInfo, 0, "test") + raw, r := parseResult(setup, t) + + if r.Level != "INFO" { + t.Fatalf("mismatch level: %s", raw) + } + + if m, err := regexp.MatchString(`/log.TestLogLog$`, r.Source.Function); err != nil || !m { + t.Fatalf("mismatch source.function: %v", raw) + } +} + +func TestLogError(t *testing.T) { + setup := beforeEach(t) + + setup.logger.Error("testerr", fmt.Errorf("dummy")) + raw, r := parseResult(setup, t) + + if r.Level != "ERROR" { + t.Fatalf("mismatch level: %s", raw) + } + + if m, err := regexp.MatchString(`/log.TestLogError$`, r.Source.Function); err != nil || !m { + t.Fatalf("mismatch source.function: %v", raw) + } + + if r.Error != "dummy" { + t.Fatalf("mismatch level: %s", raw) + } +}