Skip to content

Commit

Permalink
feat: add server option for NotFoundHandler and MethodNotAllowedHandl…
Browse files Browse the repository at this point in the history
…er (#3131)

Co-authored-by: Miles Liu <milesliu@birentech.com>
  • Loading branch information
liubing0427 and Miles Liu authored Jan 5, 2024
1 parent 4cabcaa commit 21de240
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
16 changes: 14 additions & 2 deletions transport/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,18 @@ func PathPrefix(prefix string) ServerOption {
}
}

func NotFoundHandler(handler http.Handler) ServerOption {
return func(s *Server) {
s.router.NotFoundHandler = handler
}
}

func MethodNotAllowedHandler(handler http.Handler) ServerOption {
return func(s *Server) {
s.router.MethodNotAllowedHandler = handler
}
}

// Server is an HTTP server wrapper.
type Server struct {
*http.Server
Expand Down Expand Up @@ -177,12 +189,12 @@ func NewServer(opts ...ServerOption) *Server {
strictSlash: true,
router: mux.NewRouter(),
}
srv.router.NotFoundHandler = http.DefaultServeMux
srv.router.MethodNotAllowedHandler = http.DefaultServeMux
for _, o := range opts {
o(srv)
}
srv.router.StrictSlash(srv.strictSlash)
srv.router.NotFoundHandler = http.DefaultServeMux
srv.router.MethodNotAllowedHandler = http.DefaultServeMux
srv.router.Use(srv.filter())
srv.Server = &http.Server{
Handler: FilterChain(srv.filters...)(srv.router),
Expand Down
16 changes: 16 additions & 0 deletions transport/http/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,19 @@ func TestListener(t *testing.T) {
t.Errorf("expected not empty")
}
}

func TestNotFoundHandler(t *testing.T) {
mux := http.NewServeMux()
srv := NewServer(NotFoundHandler(mux))
if !reflect.DeepEqual(srv.router.NotFoundHandler, mux) {
t.Errorf("expected %v got %v", mux, srv.router.NotFoundHandler)
}
}

func TestMethodNotAllowedHandler(t *testing.T) {
mux := http.NewServeMux()
srv := NewServer(MethodNotAllowedHandler(mux))
if !reflect.DeepEqual(srv.router.MethodNotAllowedHandler, mux) {
t.Errorf("expected %v got %v", mux, srv.router.MethodNotAllowedHandler)
}
}

0 comments on commit 21de240

Please sign in to comment.