From 2f80f631e7bb5a6659e76720d53200c6eb527159 Mon Sep 17 00:00:00 2001
From: Giuseppe Scrivano <gscrivan@redhat.com>
Date: Thu, 7 Nov 2024 15:18:07 +0100
Subject: [PATCH] chunked: refactor value into const

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
---
 pkg/chunked/compression_linux.go  | 14 ++++++++++----
 pkg/chunked/storage_linux_test.go | 17 +++++++++++++----
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/pkg/chunked/compression_linux.go b/pkg/chunked/compression_linux.go
index debbdacf33..d0c225d5ef 100644
--- a/pkg/chunked/compression_linux.go
+++ b/pkg/chunked/compression_linux.go
@@ -17,6 +17,12 @@ import (
 	expMaps "golang.org/x/exp/maps"
 )
 
+const (
+	// maxTocSize is the maximum size of a blob that we will attempt to process.
+	// It is used to prevent DoS attacks from layers that embed a very large TOC file.
+	maxTocSize = (1 << 20) * 50
+)
+
 var typesToTar = map[string]byte{
 	TypeReg:     tar.TypeReg,
 	TypeLink:    tar.TypeLink,
@@ -74,7 +80,7 @@ func readEstargzChunkedManifest(blobStream ImageSourceSeekable, blobSize int64,
 
 	size := int64(blobSize - footerSize - tocOffset)
 	// set a reasonable limit
-	if size > (1<<20)*50 {
+	if size > maxTocSize {
 		return nil, 0, errors.New("manifest too big")
 	}
 
@@ -103,7 +109,7 @@ func readEstargzChunkedManifest(blobStream ImageSourceSeekable, blobSize int64,
 					return err
 				}
 				// set a reasonable limit
-				if header.Size > (1<<20)*50 {
+				if header.Size > maxTocSize {
 					return errors.New("manifest too big")
 				}
 
@@ -163,10 +169,10 @@ func readZstdChunkedManifest(blobStream ImageSourceSeekable, tocDigest digest.Di
 	}
 
 	// set a reasonable limit
-	if manifestChunk.Length > (1<<20)*50 {
+	if manifestChunk.Length > maxTocSize {
 		return nil, nil, nil, 0, errors.New("manifest too big")
 	}
-	if manifestLengthUncompressed > (1<<20)*50 {
+	if manifestLengthUncompressed > maxTocSize {
 		return nil, nil, nil, 0, errors.New("manifest too big")
 	}
 
diff --git a/pkg/chunked/storage_linux_test.go b/pkg/chunked/storage_linux_test.go
index 4dad668420..ddbf2a5ba6 100644
--- a/pkg/chunked/storage_linux_test.go
+++ b/pkg/chunked/storage_linux_test.go
@@ -128,7 +128,11 @@ func TestGetBlobAtWithErrors(t *testing.T) {
 
 	is := &mockImageSource{streams: streams, errors: errorsC}
 
-	resultChan, err := getBlobAt(is)
+	chunks := []ImageSourceChunk{
+		{Offset: 0, Length: 1},
+		{Offset: 1, Length: 1},
+	}
+	resultChan, err := getBlobAt(is, chunks...)
 	require.NoError(t, err)
 
 	expectedErrors := []string{"error1", "error2"}
@@ -149,13 +153,18 @@ func TestGetBlobAtMixedStreamsAndErrors(t *testing.T) {
 	errorsC := make(chan error, 1)
 
 	streams <- mockReadCloserFromContent("stream1")
+	streams <- mockReadCloserFromContent("stream2")
 	errorsC <- errors.New("error1")
 	close(streams)
 	close(errorsC)
 
 	is := &mockImageSource{streams: streams, errors: errorsC}
 
-	resultChan, err := getBlobAt(is)
+	chunks := []ImageSourceChunk{
+		{Offset: 0, Length: 1},
+		{Offset: 1, Length: 1},
+	}
+	resultChan, err := getBlobAt(is, chunks...)
 	require.NoError(t, err)
 
 	var receivedStreams int
@@ -167,6 +176,6 @@ func TestGetBlobAtMixedStreamsAndErrors(t *testing.T) {
 			receivedStreams++
 		}
 	}
-	assert.Equal(t, 0, receivedStreams)
-	assert.Equal(t, 2, receivedErrors)
+	assert.Equal(t, 2, receivedStreams)
+	assert.Equal(t, 1, receivedErrors)
 }