diff --git a/rsmt2d_test.go b/rsmt2d_test.go index d81e733..ff0ba5a 100644 --- a/rsmt2d_test.go +++ b/rsmt2d_test.go @@ -7,6 +7,7 @@ import ( "github.com/celestiaorg/rsmt2d" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEdsRepairRoundtripSimple(t *testing.T) { @@ -157,3 +158,57 @@ func TestEdsRepairTwice(t *testing.T) { }) } } + +// TestRepairWithOneQuarterPopulated is motivated by a use case from +// celestia-node. It verifies that a new EDS can be populated via SetCell. After +// enough chunks have been populated, it verifies that the EDS can be repaired. +// After the EDS is repaired, the test verifies that data in a repaired cell +// matches the expected data. +func TestRepairWithOneQuarterPopulated(t *testing.T) { + edsWidth := 4 + chunkSize := 512 + + exampleEds := createExampleEds(t, chunkSize) + + eds, err := rsmt2d.NewExtendedDataSquare(rsmt2d.NewLeoRSCodec(), rsmt2d.NewDefaultTree, uint(edsWidth), uint(chunkSize)) + require.NoError(t, err) + + // Populate EDS with 1/4 of chunks using SetCell + err = eds.SetCell(0, 0, exampleEds.GetCell(0, 0)) + require.NoError(t, err) + err = eds.SetCell(0, 1, exampleEds.GetCell(0, 1)) + require.NoError(t, err) + err = eds.SetCell(1, 0, exampleEds.GetCell(1, 0)) + require.NoError(t, err) + err = eds.SetCell(1, 1, exampleEds.GetCell(1, 1)) + require.NoError(t, err) + + // Verify that an unpopulated cell returns nil + assert.Nil(t, eds.GetCell(3, 3)) + + rowRoots, err := exampleEds.RowRoots() + require.NoError(t, err) + colRoots, err := exampleEds.ColRoots() + require.NoError(t, err) + + // Repair the EDS + err = eds.Repair(rowRoots, colRoots) + assert.NoError(t, err) + + assert.Equal(t, exampleEds.Flattened(), eds.Flattened()) +} + +func createExampleEds(t *testing.T, chunkSize int) (eds *rsmt2d.ExtendedDataSquare) { + ones := bytes.Repeat([]byte{1}, chunkSize) + twos := bytes.Repeat([]byte{2}, chunkSize) + threes := bytes.Repeat([]byte{3}, chunkSize) + fours := bytes.Repeat([]byte{4}, chunkSize) + ods := [][]byte{ + ones, twos, + threes, fours, + } + + eds, err := rsmt2d.ComputeExtendedDataSquare(ods, rsmt2d.NewLeoRSCodec(), rsmt2d.NewDefaultTree) + require.NoError(t, err) + return eds +}