Skip to content

Commit

Permalink
feat: improve recovery middleware tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tigerwill90 committed Jun 25, 2024
1 parent 525d4af commit fab647c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 46 deletions.
65 changes: 22 additions & 43 deletions internal/slogpretty/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ const (
initialBufferSize = 1024
)

var (
_ slog.Handler = (*LogHandler)(nil)
_ slog.Handler = (*NoopHandler)(nil)
)
var _ slog.Handler = (*LogHandler)(nil)

var logBufPool = sync.Pool{
New: func() any {
b := make([]byte, 0, initialBufferSize)
Expand All @@ -33,10 +31,10 @@ var logBufPool = sync.Pool{

var (
Handler = &LogHandler{
w: &lockedWriter{w: os.Stdout},
errW: &lockedWriter{w: os.Stderr},
lvl: slog.LevelDebug,
groupAttr: make([]GroupOrAttrs, 0),
We: &lockedWriter{w: os.Stderr},
Wo: &lockedWriter{w: os.Stdout},
Lvl: slog.LevelDebug,
Goa: make([]GroupOrAttrs, 0),
}
timeFormat = fmt.Sprintf("%s %s", time.DateOnly, time.TimeOnly)
)
Expand All @@ -54,14 +52,14 @@ type GroupOrAttrs struct {
}

type LogHandler struct {
w io.Writer
errW io.Writer
lvl slog.Leveler
groupAttr []GroupOrAttrs
We io.Writer
Wo io.Writer
Lvl slog.Leveler
Goa []GroupOrAttrs
}

func (h *LogHandler) Enabled(_ netcontext.Context, level slog.Level) bool {
return level >= h.lvl.Level()
return level >= h.Lvl.Level()
}

func (h *LogHandler) Handle(_ netcontext.Context, record slog.Record) error {
Expand Down Expand Up @@ -110,7 +108,7 @@ func (h *LogHandler) Handle(_ netcontext.Context, record slog.Record) error {
buf = append(buf, " | "...)

lastGroup := ""
for _, goa := range h.groupAttr {
for _, goa := range h.Goa {
switch {
case goa.group != "":
lastGroup += goa.group + "."
Expand Down Expand Up @@ -140,11 +138,11 @@ func (h *LogHandler) Handle(_ netcontext.Context, record slog.Record) error {
buf[len(buf)-1] = '\n'

if record.Level >= slog.LevelError {
if _, err := h.errW.Write(buf); err != nil {
if _, err := h.We.Write(buf); err != nil {
return fmt.Errorf("failed to write buffer: %w", err)
}
} else {
if _, err := h.w.Write(buf); err != nil {
if _, err := h.Wo.Write(buf); err != nil {
return fmt.Errorf("failed to write buffer: %w", err)
}
}
Expand All @@ -159,19 +157,19 @@ func (h *LogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
}

return &LogHandler{
w: h.w,
errW: h.errW,
lvl: h.lvl,
groupAttr: append(h.groupAttr, newAttrs...),
We: h.We,
Wo: h.Wo,
Lvl: h.Lvl,
Goa: append(h.Goa, newAttrs...),
}
}

func (h *LogHandler) WithGroup(name string) slog.Handler {
return &LogHandler{
w: h.w,
errW: h.errW,
lvl: h.lvl,
groupAttr: append(h.groupAttr, GroupOrAttrs{group: name}),
We: h.We,
Wo: h.Wo,
Lvl: h.Lvl,
Goa: append(h.Goa, GroupOrAttrs{group: name}),
}
}

Expand Down Expand Up @@ -259,22 +257,3 @@ func latencyColor(d time.Duration) string {

return ansi.FgRed
}

type NoopHandler struct {
}

func (n NoopHandler) Enabled(_ netcontext.Context, _ slog.Level) bool {
return true
}

func (n NoopHandler) Handle(_ netcontext.Context, _ slog.Record) error {
return nil
}

func (n NoopHandler) WithAttrs(_ []slog.Attr) slog.Handler {
return NoopHandler{}
}

func (n NoopHandler) WithGroup(_ string) slog.Handler {
return NoopHandler{}
}
25 changes: 22 additions & 3 deletions recovery_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package fox

import (
"bytes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tigerwill90/fox/internal/slogpretty"
"log/slog"
"net"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -40,7 +42,14 @@ func TestAbortHandler(t *testing.T) {
}

func TestRecoveryMiddleware(t *testing.T) {
m := CustomRecoveryWithLogHandler(slogpretty.NoopHandler{}, func(c Context, err any) {
woBuf := bytes.NewBuffer(nil)
weBuf := bytes.NewBuffer(nil)

m := CustomRecoveryWithLogHandler(&slogpretty.LogHandler{
We: weBuf,
Wo: woBuf,
Lvl: slog.LevelDebug,
}, func(c Context, err any) {
c.Writer().WriteHeader(http.StatusInternalServerError)
_, _ = c.Writer().Write([]byte(err.(string)))
})
Expand All @@ -59,17 +68,26 @@ func TestRecoveryMiddleware(t *testing.T) {
r.ServeHTTP(w, req)
require.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, errMsg, w.Body.String())
assert.Equal(t, woBuf.Len(), 0)
assert.NotEqual(t, weBuf.Len(), 0)
}

func TestRecoveryMiddlewareWithBrokenPipe(t *testing.T) {
woBuf := bytes.NewBuffer(nil)
weBuf := bytes.NewBuffer(nil)

expectMsgs := map[syscall.Errno]string{
syscall.EPIPE: "broken pipe",
syscall.ECONNRESET: "connection reset by peer",
}

for errno, expectMsg := range expectMsgs {
t.Run(expectMsg, func(t *testing.T) {
f := New(WithMiddleware(CustomRecoveryWithLogHandler(slogpretty.NoopHandler{}, func(c Context, err any) {
f := New(WithMiddleware(CustomRecoveryWithLogHandler(&slogpretty.LogHandler{
We: weBuf,
Wo: woBuf,
Lvl: slog.LevelDebug,
}, func(c Context, err any) {
if !connIsBroken(err) {
http.Error(c.Writer(), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
Expand All @@ -82,8 +100,9 @@ func TestRecoveryMiddlewareWithBrokenPipe(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
w := httptest.NewRecorder()
f.ServeHTTP(w, req)

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, woBuf.Len(), 0)
assert.NotEqual(t, weBuf.Len(), 0)
})
}
}

0 comments on commit fab647c

Please sign in to comment.