diff --git a/config/config.go b/config/config.go index 3b69233..06d2b36 100644 --- a/config/config.go +++ b/config/config.go @@ -48,9 +48,11 @@ type IPFSConfig struct { } type APIConfig struct { - ListenAddr string `mapstructure:"listen-addr"` - WriteTimeout int64 `mapstructure:"write-timeout"` - ReadTimeout int64 `mapstructure:"read-timeout"` + ListenAddr string `mapstructure:"listen-addr"` + WriteTimeout int64 `mapstructure:"write-timeout"` + ReadTimeout int64 `mapstructure:"read-timeout"` + MaxConnections int `mapstructure:"max-connections"` + MaxRequestBodySize int64 `mapstructure:"max-request-body-size"` } func DefaultConfig() *Config { @@ -81,9 +83,11 @@ func DefaultConfig() *Config { IPFSNodeAddr: "127.0.0.1:5001", }, API: APIConfig{ - ListenAddr: "127.0.0.1:8080", - WriteTimeout: 60, - ReadTimeout: 15, + ListenAddr: "127.0.0.1:8080", + WriteTimeout: 60, + ReadTimeout: 15, + MaxConnections: 50, + MaxRequestBodySize: 4 << (10 * 2), // 4MB }, } } diff --git a/config/toml.go b/config/toml.go index 91eb0c8..9f87ecb 100644 --- a/config/toml.go +++ b/config/toml.go @@ -70,6 +70,8 @@ ipfs-node-addr = "{{ .IPFS.IPFSNodeAddr }}" listen-addr = "{{ .API.ListenAddr }}" write-timeout = "{{ .API.WriteTimeout }}" read-timeout = "{{ .API.ReadTimeout }}" +max-connections = "{{ .API.MaxConnections }}" +max-request-body-size = "{{ .API.MaxRequestBodySize }}" ` var configTemplate *template.Template diff --git a/server/middleware/limit.go b/server/middleware/limit.go new file mode 100644 index 0000000..69e8728 --- /dev/null +++ b/server/middleware/limit.go @@ -0,0 +1,31 @@ +package middleware + +import ( + "net/http" +) + +type limitMiddleware struct { + maxRequestBodySize int64 +} + +func NewLimitMiddleware(maxRequestBodySize int64) *limitMiddleware { + return &limitMiddleware{ + maxRequestBodySize, + } +} + +// Middleware limits the request body size. +// This is done by first constraining to the ContentLength of the request headder, +// and then reading the actual Body to constraint it. +func (mw *limitMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength > mw.maxRequestBodySize { + http.Error(w, "request body too large", http.StatusBadRequest) + return + } + r.Body = http.MaxBytesReader(w, r.Body, mw.maxRequestBodySize) + defer r.Body.Close() + + next.ServeHTTP(w, r) + }) +} diff --git a/server/middleware/limit_test.go b/server/middleware/limit_test.go new file mode 100644 index 0000000..cbac012 --- /dev/null +++ b/server/middleware/limit_test.go @@ -0,0 +1,102 @@ +package middleware_test + +import ( + "bytes" + "crypto/rand" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/medibloc/panacea-oracle/server/middleware" + "github.com/stretchr/testify/require" +) + +func TestBodySizeSmallerThanLimitSetting(t *testing.T) { + testLimitMiddlewareHTTPRequest( + t, + newRequest(newRandomBody(1023)), + 1024, + http.StatusOK, + "", + ) +} + +func TestBodySizeSameLimitSetting(t *testing.T) { + testLimitMiddlewareHTTPRequest( + t, + newRequest(newRandomBody(1024)), + 1024, + http.StatusOK, + "", + ) +} + +func TestBodySizeLargeThanLimitSetting(t *testing.T) { + testLimitMiddlewareHTTPRequest( + t, + newRequest(newRandomBody(1025)), + 1024, + http.StatusBadRequest, + "request body too large", + ) +} + +func TestDifferentBodySizeAndHeaderContentSize(t *testing.T) { + req := newRequest(newRandomBody(1025)) + req.ContentLength = 1024 + + testLimitMiddlewareHTTPRequest( + t, + req, + 1024, + http.StatusBadRequest, + "request body too large", + ) +} + +func newRandomBody(size int) []byte { + body := make([]byte, size) + if _, err := rand.Read(body); err != nil { + panic(err) + } + + return body +} + +func newRequest(body []byte) *http.Request { + return httptest.NewRequest( + "GET", + "http://test.com", + bytes.NewBuffer(body), + ) +} + +func testLimitMiddlewareHTTPRequest( + t *testing.T, + req *http.Request, + maxRequestBodySize int64, + statusCode int, + bodyMsg string, +) { + w := httptest.NewRecorder() + mw := middleware.NewLimitMiddleware(maxRequestBodySize) + testHandler := mw.Middleware( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } + }), + ) + + testHandler.ServeHTTP(w, req) + + resp := w.Result() + require.Equal(t, statusCode, resp.StatusCode) + if bodyMsg != "" { + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), bodyMsg) + } +} diff --git a/server/netutil_test.go b/server/netutil_test.go new file mode 100644 index 0000000..8a21444 --- /dev/null +++ b/server/netutil_test.go @@ -0,0 +1,90 @@ +package server_test + +import ( + "io" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/net/netutil" +) + +func TestNetUtil(t *testing.T) { + + lis := &fakeListener{timeWait: 1} + + limitLis := netutil.LimitListener(lis, 2) + + wg := &sync.WaitGroup{} + start := time.Now() + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + + conn, err := limitLis.Accept() + require.NoError(t, err) + defer conn.Close() + }(i) + } + + wg.Wait() + end := time.Now() + + // Send 10 requests, process 2 at a time, and take 1 second per request. + // This request test should take 5 to 6 seconds. + require.True(t, start.Add(time.Second*5).Before(end)) + require.True(t, start.Add(time.Second*6).After(end)) + +} + +type fakeListener struct { + timeWait int64 +} + +// Accept waits for and returns the next connection to the listener. +func (l *fakeListener) Accept() (net.Conn, error) { + time.Sleep(time.Second * time.Duration(l.timeWait)) + + return fakeNetConn{}, nil +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (l *fakeListener) Close() error { + return nil +} + +// Addr returns the listener's network address. +func (l *fakeListener) Addr() net.Addr { + return fakeAddr(1) +} + +type fakeNetConn struct { + io.Reader + io.Writer +} + +func (c fakeNetConn) Close() error { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } +func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } +func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type fakeAddr int + +var ( + localAddr = fakeAddr(1) + remoteAddr = fakeAddr(2) +) + +func (a fakeAddr) Network() string { + return "net" +} + +func (a fakeAddr) String() string { + return "str" +} diff --git a/server/server.go b/server/server.go index 047268a..0a33f8c 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "net" "net/http" "time" @@ -11,19 +12,24 @@ import ( "github.com/medibloc/panacea-oracle/server/service/status" "github.com/medibloc/panacea-oracle/service" log "github.com/sirupsen/logrus" + "golang.org/x/net/netutil" ) type Server struct { *http.Server + maxConnections int } func New(svc service.Service) *Server { router := mux.NewRouter() + conf := svc.Config() + limitMiddleware := middleware.NewLimitMiddleware(conf.API.MaxRequestBodySize) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(svc.QueryClient()) dealRouter := router.PathPrefix("/v0/data-deal").Subrouter() dealRouter.Use( + limitMiddleware.Middleware, jwtAuthMiddleware.Middleware, ) @@ -35,14 +41,24 @@ func New(svc service.Service) *Server { return &Server{ &http.Server{ Handler: router, - Addr: svc.Config().API.ListenAddr, - WriteTimeout: time.Duration(svc.Config().API.WriteTimeout) * time.Second, - ReadTimeout: time.Duration(svc.Config().API.ReadTimeout) * time.Second, + Addr: conf.API.ListenAddr, + WriteTimeout: time.Duration(conf.API.WriteTimeout) * time.Second, + ReadTimeout: time.Duration(conf.API.ReadTimeout) * time.Second, }, + conf.API.MaxConnections, } } func (srv *Server) Run() error { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + lis, err := net.Listen("tcp", addr) + if err != nil { + return err + } + log.Infof("HTTP server is started: %s", srv.Addr) - return srv.ListenAndServe() + return srv.Serve(netutil.LimitListener(lis, srv.maxConnections)) }