Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix slog source attribute #144

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions log/slog.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"log/slog"
"os"
"time"
"runtime"
"context"

"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/withstack"
Expand All @@ -18,6 +20,18 @@ type RelayLogger struct {
var relayLogger *RelayLogger

func InitLogger(logLevel, format, output string) error {
// output
switch output {
case "stdout":
return InitLoggerWithWriter(logLevel, format, os.Stdout)
case "stderr":
return InitLoggerWithWriter(logLevel, format, os.Stderr)
default:
return errors.New("invalid log output")
}
}

func InitLoggerWithWriter(logLevel, format string, writer io.Writer) error {
// level
var slogLevel slog.Level
if err := slogLevel.UnmarshalText([]byte(logLevel)); err != nil {
Expand All @@ -28,17 +42,6 @@ func InitLogger(logLevel, format, output string) error {
AddSource: true,
}

// output
var writer io.Writer
switch output {
case "stdout":
writer = os.Stdout
case "stderr":
writer = os.Stderr
default:
return errors.New("invalid log output")
}

var slogLogger *slog.Logger
// format
switch format {
Expand All @@ -63,17 +66,38 @@ func InitLogger(logLevel, format, output string) error {
return nil
}

func (rl *RelayLogger) Error(msg string, err error, otherArgs ...any) {
err = withstack.WithStackDepth(err, 1)
func (rl *RelayLogger) log(logLevel slog.Level, skipCallDepth int, msg string, args ...any) {
ctx := context.Background();
if !rl.Logger.Enabled(ctx, logLevel) {
return
}

var pcs [1]uintptr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you need to add the following?

if !l.Enabled(ctx, level) {
	return
}

Ref: https://github.com/golang/go/blob/go1.22.5/src/log/slog/logger.go#L242-L244

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. It is fixed.

runtime.Callers(2 + skipCallDepth, pcs[:]) // skip [Callers, this func, ...]

record := slog.NewRecord(time.Now(), logLevel, msg, pcs[0])
record.Add(args...)

// note that official log function also ignores Handle() error
_ = rl.Logger.Handler().Handle(ctx, record)
}

func (rl *RelayLogger) error(skipCallDepth int, msg string, err error, otherArgs ...any) {
err = withstack.WithStackDepth(err, 1 + skipCallDepth)
var args []any
args = append(args, "error", err)
args = append(args, "stack", fmt.Sprintf("%+v", err))
args = append(args, otherArgs...)
rl.Logger.Error(msg, args...)

rl.log(slog.LevelError, 1 + skipCallDepth, msg, args...)
}

func (rl *RelayLogger) Error(msg string, err error, otherArgs ...any) {
rl.error(1, msg, err, otherArgs...)
}

func (rl *RelayLogger) Fatal(msg string, err error, otherArgs ...any) {
rl.Error(msg, err, otherArgs...)
rl.error(1, msg, err, otherArgs...)
panic(msg)
}

Expand Down Expand Up @@ -190,5 +214,5 @@ func (rl *RelayLogger) WithModule(
func (rl *RelayLogger) TimeTrack(start time.Time, name string, otherArgs ...any) {
elapsed := time.Since(start)
allArgs := append([]any{"name", name, "elapsed", elapsed.Nanoseconds()}, otherArgs...)
rl.Logger.Info("time track", allArgs...)
rl.log(slog.LevelInfo, 1, "time track", allArgs...)
}
97 changes: 97 additions & 0 deletions log/slog_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package log

import (
"testing"

"fmt"
"log/slog"
"bytes"
"encoding/json"
"regexp"
)

type setupType struct {
logger *RelayLogger
buffer bytes.Buffer
}

func beforeEach(t *testing.T) *setupType {
var r setupType

err := InitLoggerWithWriter("info", "json", &r.buffer)
if err != nil {
t.Fatal(err)
}

r.logger = GetLogger()

return &r
}

type logType struct {
Time string
Level string
Source struct {
Function string
File string
Line int
}
Msg string
Stack string
Error string
}

func parseResult(setup *setupType, t *testing.T) (string, logType) {
raw := setup.buffer.String()
var parsed logType

err := json.Unmarshal(setup.buffer.Bytes(), &parsed)
if err != nil {
t.Fatalf("fail to parse log: %v: %s", err, raw)
}

return raw, parsed
}

func TestLogLevel(t *testing.T) {
setup := beforeEach(t)

setup.logger.log(slog.LevelDebug, 0, "test")
if 0 < setup.buffer.Len() {
t.Fatalf("debug log is output: %s", setup.buffer.String())
}
}

func TestLogLog(t *testing.T) {
setup := beforeEach(t)

setup.logger.log(slog.LevelInfo, 0, "test")
raw, r := parseResult(setup, t)

if r.Level != "INFO" {
t.Fatalf("mismatch level: %s", raw)
}

if m, err := regexp.MatchString(`/log.TestLogLog$`, r.Source.Function); err != nil || !m {
t.Fatalf("mismatch source.function: %v", raw)
}
}

func TestLogError(t *testing.T) {
setup := beforeEach(t)

setup.logger.Error("testerr", fmt.Errorf("dummy"))
raw, r := parseResult(setup, t)

if r.Level != "ERROR" {
t.Fatalf("mismatch level: %s", raw)
}

if m, err := regexp.MatchString(`/log.TestLogError$`, r.Source.Function); err != nil || !m {
t.Fatalf("mismatch source.function: %v", raw)
}

if r.Error != "dummy" {
t.Fatalf("mismatch level: %s", raw)
}
}
Loading