diff --git a/datasquare_test.go b/datasquare_test.go index 42514f7..78331db 100644 --- a/datasquare_test.go +++ b/datasquare_test.go @@ -188,9 +188,9 @@ func TestInvalidSquareExtension(t *testing.T) { } } -// TestRoots verifies that the row roots and column roots are equal for a 1x1 +// Test_getRoots verifies that the row roots and column roots are equal for a 1x1 // square. -func TestRoots(t *testing.T) { +func Test_getRoots(t *testing.T) { result, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree, 2) assert.NoError(t, err) diff --git a/extendeddatasquare.go b/extendeddatasquare.go index 9193d87..f684020 100644 --- a/extendeddatasquare.go +++ b/extendeddatasquare.go @@ -322,6 +322,24 @@ func (eds *ExtendedDataSquare) Equals(other *ExtendedDataSquare) bool { return true } +// Roots returns a byte slice with this eds's RowRoots and ColRoots +// concatenated. +func (eds *ExtendedDataSquare) Roots() (roots [][]byte, err error) { + rowRoots, err := eds.RowRoots() + if err != nil { + return nil, err + } + colRoots, err := eds.ColRoots() + if err != nil { + return nil, err + } + + roots = make([][]byte, 0, len(rowRoots)+len(colRoots)) + roots = append(roots, rowRoots...) + roots = append(roots, colRoots...) + return roots, nil +} + // validateEdsWidth returns an error if edsWidth is not a valid width for an // extended data square. func validateEdsWidth(edsWidth uint) error { diff --git a/extendeddatasquare_test.go b/extendeddatasquare_test.go index 0815418..2cad1b2 100644 --- a/extendeddatasquare_test.go +++ b/extendeddatasquare_test.go @@ -426,6 +426,49 @@ func TestEquals(t *testing.T) { }) } +func TestRoots(t *testing.T) { + t.Run("returns roots for a 4x4 EDS", func(t *testing.T) { + eds, err := ComputeExtendedDataSquare([][]byte{ + ones, twos, + threes, fours, + }, NewLeoRSCodec(), NewDefaultTree) + require.NoError(t, err) + + roots, err := eds.Roots() + require.NoError(t, err) + assert.Len(t, roots, 8) + + rowRoots, err := eds.RowRoots() + require.NoError(t, err) + + colRoots, err := eds.ColRoots() + require.NoError(t, err) + + assert.Equal(t, roots[0], rowRoots[0]) + assert.Equal(t, roots[1], rowRoots[1]) + assert.Equal(t, roots[2], rowRoots[2]) + assert.Equal(t, roots[3], rowRoots[3]) + assert.Equal(t, roots[4], colRoots[0]) + assert.Equal(t, roots[5], colRoots[1]) + assert.Equal(t, roots[6], colRoots[2]) + assert.Equal(t, roots[7], colRoots[3]) + }) + + t.Run("returns an error for an incomplete EDS", func(t *testing.T) { + eds, err := ComputeExtendedDataSquare([][]byte{ + ones, twos, + threes, fours, + }, NewLeoRSCodec(), NewDefaultTree) + require.NoError(t, err) + + // set a cell to nil to make the EDS incomplete + eds.setCell(0, 0, nil) + + _, err = eds.Roots() + assert.Error(t, err) + }) +} + func createExampleEds(t *testing.T, chunkSize int) (eds *ExtendedDataSquare) { ones := bytes.Repeat([]byte{1}, chunkSize) twos := bytes.Repeat([]byte{2}, chunkSize)