diff --git a/cmd/rekor-server/app/root.go b/cmd/rekor-server/app/root.go index e258c2aa5..f1fbdd10d 100644 --- a/cmd/rekor-server/app/root.go +++ b/cmd/rekor-server/app/root.go @@ -105,6 +105,8 @@ Memory and file-based signers should only be used for testing.`) rootCmd.PersistentFlags().StringSlice("enabled_api_endpoints", operationIds, "list of API endpoints to enable using operationId from openapi.yaml") + rootCmd.PersistentFlags().Uint64("max_request_body_size", 0, "maximum size for HTTP request body, in bytes; set to 0 for unlimited") + if err := viper.BindPFlags(rootCmd.PersistentFlags()); err != nil { log.Logger.Fatal(err) } diff --git a/cmd/rekor-server/e2e_test.go b/cmd/rekor-server/e2e_test.go index b0002e059..4c74b1c1b 100644 --- a/cmd/rekor-server/e2e_test.go +++ b/cmd/rekor-server/e2e_test.go @@ -21,10 +21,13 @@ import ( "bufio" "bytes" "crypto/sha256" + "encoding/base64" "encoding/hex" "encoding/json" "fmt" + "io" "io/ioutil" + "math/rand" "net/http" "os" "path/filepath" @@ -358,3 +361,23 @@ func TestSearchQueryMalformedEntry(t *testing.T) { t.Fatalf("expected status 400, got %d instead", resp.StatusCode) } } + +func TestHTTPMaxRequestBodySize(t *testing.T) { + // default value is 32Mb so let's try to propose something bigger + pipeR, pipeW := io.Pipe() + go func() { + _, _ = io.CopyN(base64.NewEncoder(base64.StdEncoding, pipeW), rand.New(rand.NewSource(123)), 33*1024768) + pipeW.Close() + }() + // json parsing will hit first so we need to make sure this is valid JSON + bodyReader := io.MultiReader(strings.NewReader("{ \"key\": \""), pipeR, strings.NewReader("\"}")) + resp, err := http.Post(fmt.Sprintf("%s/api/v1/log/entries/retrieve", rekorServer()), + "application/json", + bodyReader) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusRequestEntityTooLarge { + t.Fatalf("expected status %d, got %d instead", http.StatusRequestEntityTooLarge, resp.StatusCode) + } +} diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 6a0b780c4..0721f7229 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -32,6 +32,7 @@ services: "--enable_attestation_storage", "--attestation_storage_bucket=file:///var/run/attestations", "--enable_killswitch", + "--max_request_body_size=32792576", ] ports: - "3000:3000" diff --git a/pkg/generated/restapi/configure_rekor_server.go b/pkg/generated/restapi/configure_rekor_server.go index 90cf05100..2041f0759 100644 --- a/pkg/generated/restapi/configure_rekor_server.go +++ b/pkg/generated/restapi/configure_rekor_server.go @@ -20,6 +20,7 @@ package restapi import ( "context" "crypto/tls" + go_errors "errors" "fmt" "net/http" "net/http/httputil" @@ -251,6 +252,10 @@ func (l *logFormatter) NewLogEntry(r *http.Request) middleware.LogEntry { // So this is a good place to plug in a panic handling middleware, logging and metrics func setupGlobalMiddleware(handler http.Handler) http.Handler { returnHandler := recoverer(handler) + maxReqBodySize := viper.GetInt64("max_request_body_size") + if maxReqBodySize > 0 { + returnHandler = maxBodySize(maxReqBodySize, returnHandler) + } middleware.DefaultLogger = middleware.RequestLogger(&logFormatter{}) returnHandler = middleware.Logger(returnHandler) returnHandler = middleware.Heartbeat("/ping")(returnHandler) @@ -339,8 +344,18 @@ func logAndServeError(w http.ResponseWriter, r *http.Request, err error) { } else { log.ContextLogger(ctx).Error(err) } + if compErr, ok := err.(*errors.CompositeError); ok { + // iterate over composite error looking for something more specific + for _, embeddedErr := range compErr.Errors { + var maxBytesError *http.MaxBytesError + if parseErr, ok := embeddedErr.(*errors.ParseError); ok && go_errors.As(parseErr.Reason, &maxBytesError) { + err = errors.New(http.StatusRequestEntityTooLarge, http.StatusText(http.StatusRequestEntityTooLarge)) + break + } + } + } requestFields := map[string]interface{}{} - if err := mapstructure.Decode(r, &requestFields); err == nil { + if decodeErr := mapstructure.Decode(r, &requestFields); decodeErr == nil { log.ContextLogger(ctx).Debug(requestFields) } errors.ServeError(w, r, err) @@ -386,3 +401,13 @@ func recoverer(next http.Handler) http.Handler { return http.HandlerFunc(fn) } + +// maxBodySize limits the request body +func maxBodySize(maxLength int64, next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLength) + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +}