diff --git a/codecs.go b/codecs.go index 6326a5c..210beac 100644 --- a/codecs.go +++ b/codecs.go @@ -23,6 +23,9 @@ type Codec interface { MaxChunks() int // Name returns the name of the codec. Name() string + // ValidateChunkSize returns an error if this codec does not support + // chunkSize. Returns nil if chunkSize is supported. + ValidateChunkSize(chunkSize int) error } // codecs is a global map used for keeping track of registered codecs for testing and JSON unmarshalling diff --git a/extendeddatasquare.go b/extendeddatasquare.go index f46abdb..d83b5cc 100644 --- a/extendeddatasquare.go +++ b/extendeddatasquare.go @@ -44,7 +44,8 @@ func (eds *ExtendedDataSquare) UnmarshalJSON(b []byte) error { return nil } -// ComputeExtendedDataSquare computes the extended data square for some chunks of data. +// ComputeExtendedDataSquare computes the extended data square for some chunks +// of original data. func ComputeExtendedDataSquare( data [][]byte, codec Codec, @@ -55,6 +56,10 @@ func ComputeExtendedDataSquare( } chunkSize := getChunkSize(data) + err := codec.ValidateChunkSize(int(chunkSize)) + if err != nil { + return nil, err + } ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize)) if err != nil { return nil, err @@ -80,6 +85,10 @@ func ImportExtendedDataSquare( } chunkSize := getChunkSize(data) + err := codec.ValidateChunkSize(int(chunkSize)) + if err != nil { + return nil, err + } ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize)) if err != nil { return nil, err @@ -104,6 +113,10 @@ func NewExtendedDataSquare(codec Codec, treeCreatorFn TreeConstructorFn, edsWidt if err != nil { return nil, err } + err = codec.ValidateChunkSize(int(chunkSize)) + if err != nil { + return nil, err + } data := make([][]byte, edsWidth*edsWidth) dataSquare, err := newDataSquare(data, treeCreatorFn, chunkSize) diff --git a/extendeddatasquare_test.go b/extendeddatasquare_test.go index ddf351a..47fffc8 100644 --- a/extendeddatasquare_test.go +++ b/extendeddatasquare_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ShardSize = 64 @@ -37,10 +38,6 @@ func TestComputeExtendedDataSquare(t *testing.T) { testCases := []testCase{ { name: "1x1", - // NOTE: data must contain byte slices that are a multiple of 64 - // bytes. - // See https://github.com/catid/leopard/blob/22ddc7804998d31c8f1a2617ee720e063b1fa6cd/README.md?plain=1#L27 - // See https://github.com/klauspost/reedsolomon/blob/fd3e6910a7e457563469172968f456ad9b7696b6/README.md?plain=1#L403 data: [][]byte{ones}, want: [][][]byte{ {ones, ones}, @@ -69,6 +66,26 @@ func TestComputeExtendedDataSquare(t *testing.T) { assert.Equal(t, tc.want, result.squareRow) }) } + + t.Run("returns an error if chunkSize is not a multiple of 64", func(t *testing.T) { + chunk := bytes.Repeat([]byte{1}, 65) + _, err := ComputeExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), NewDefaultTree) + assert.Error(t, err) + }) +} + +func TestImportExtendedDataSquare(t *testing.T) { + t.Run("is able to import an EDS", func(t *testing.T) { + eds := createExampleEds(t, ShardSize) + got, err := ImportExtendedDataSquare(eds.Flattened(), NewLeoRSCodec(), NewDefaultTree) + assert.NoError(t, err) + assert.Equal(t, eds.Flattened(), got.Flattened()) + }) + t.Run("returns an error if chunkSize is not a multiple of 64", func(t *testing.T) { + chunk := bytes.Repeat([]byte{1}, 65) + _, err := ImportExtendedDataSquare([][]byte{chunk}, NewLeoRSCodec(), NewDefaultTree) + assert.Error(t, err) + }) } func TestMarshalJSON(t *testing.T) { @@ -104,6 +121,13 @@ func TestNewExtendedDataSquare(t *testing.T) { _, err := NewExtendedDataSquare(NewLeoRSCodec(), NewDefaultTree, edsWidth, chunkSize) assert.Error(t, err) }) + t.Run("returns an error if chunkSize is not a multiple of 64", func(t *testing.T) { + edsWidth := uint(1) + chunkSize := uint(65) + + _, err := NewExtendedDataSquare(NewLeoRSCodec(), NewDefaultTree, edsWidth, chunkSize) + assert.Error(t, err) + }) t.Run("returns a 4x4 EDS", func(t *testing.T) { edsWidth := uint(4) chunkSize := uint(512) @@ -274,3 +298,18 @@ func genRandDS(width int, chunkSize int) [][]byte { } return ds } + +func createExampleEds(t *testing.T, chunkSize int) (eds *ExtendedDataSquare) { + ones := bytes.Repeat([]byte{1}, chunkSize) + twos := bytes.Repeat([]byte{2}, chunkSize) + threes := bytes.Repeat([]byte{3}, chunkSize) + fours := bytes.Repeat([]byte{4}, chunkSize) + ods := [][]byte{ + ones, twos, + threes, fours, + } + + eds, err := ComputeExtendedDataSquare(ods, NewLeoRSCodec(), NewDefaultTree) + require.NoError(t, err) + return eds +} diff --git a/leopard.go b/leopard.go index 2810d78..3c3d256 100644 --- a/leopard.go +++ b/leopard.go @@ -1,6 +1,7 @@ package rsmt2d import ( + "fmt" "sync" "github.com/klauspost/reedsolomon" @@ -82,6 +83,17 @@ func (l *LeoRSCodec) Name() string { return Leopard } +// ValidateChunkSize returns an error if this codec does not support +// chunkSize. Returns nil if chunkSize is supported. +func (l *LeoRSCodec) ValidateChunkSize(chunkSize int) error { + // See https://github.com/catid/leopard/blob/22ddc7804998d31c8f1a2617ee720e063b1fa6cd/README.md?plain=1#L27 + // See https://github.com/klauspost/reedsolomon/blob/fd3e6910a7e457563469172968f456ad9b7696b6/README.md?plain=1#L403 + if chunkSize%64 != 0 { + return fmt.Errorf("chunkSize %v must be a multiple of 64 bytes", chunkSize) + } + return nil +} + func NewLeoRSCodec() *LeoRSCodec { return &LeoRSCodec{} }