diff --git a/extendeddatacrossword_test.go b/extendeddatacrossword_test.go index 4a77363..370b36e 100644 --- a/extendeddatacrossword_test.go +++ b/extendeddatacrossword_test.go @@ -168,41 +168,52 @@ func TestCannotRepairSquareWithBadRoots(t *testing.T) { } func TestCorruptedEdsReturnsErrByzantineData(t *testing.T) { - bufferSize := 64 - corruptChunk := bytes.Repeat([]byte{66}, bufferSize) + shareSize := 64 + corruptChunk := bytes.Repeat([]byte{66}, shareSize) tests := []struct { - name string - // Size of each share, in bytes - shareSize int - cells [][]byte - values [][]byte + name string + coords [][]uint + values [][]byte }{ - {"BadRow/OriginalData", bufferSize, [][]byte{{0, 0}}, [][]byte{corruptChunk}}, - {"BadRow/ExtendedData", bufferSize, [][]byte{{0, 3}}, [][]byte{corruptChunk}}, - {"BadColumn/OriginalData", bufferSize, [][]byte{{0, 0}, {0, 1}, {0, 2}, {0, 3}}, [][]byte{corruptChunk, nil, nil, nil}}, - {"BadColumn/OriginalData", bufferSize, [][]byte{{3, 0}, {0, 1}, {0, 2}, {0, 3}}, [][]byte{corruptChunk, nil, nil, nil}}, + { + name: "corrupt a chunk in the original data square", + coords: [][]uint{{0, 0}}, + values: [][]byte{corruptChunk}, + }, + { + name: "corrupt a chunk in the extended data square", + coords: [][]uint{{0, 3}}, + values: [][]byte{corruptChunk}, + }, + { + name: "corrupt a chunk at (0, 0) and delete shares from the rest of the row", + coords: [][]uint{{0, 0}, {0, 1}, {0, 2}, {0, 3}}, + values: [][]byte{corruptChunk, nil, nil, nil}, + }, + { + name: "corrupt a chunk at (3, 0) and delete part of the first row ", + coords: [][]uint{{3, 0}, {0, 1}, {0, 2}, {0, 3}}, + values: [][]byte{corruptChunk, nil, nil, nil}, + }, } for codecName, codec := range codecs { t.Run(codecName, func(t *testing.T) { - original := createTestEds(codec, bufferSize) - - var byzData *ErrByzantineData for _, test := range tests { t.Run(test.name, func(t *testing.T) { - corrupted, err := original.deepCopy(codec) - if err != nil { - t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, codecName) - } - for i := 0; i < len(test.cells); i++ { - corrupted.setCell(uint(test.cells[i][0]), uint(test.cells[i][1]), test.values[i]) - } - err = corrupted.Repair(corrupted.getRowRoots(), corrupted.getColRoots()) - if !errors.As(err, &byzData) { - // due to parallelisation, the ErrByzantineData axis may be either row or col - t.Errorf("did not return a ErrByzantineData for a bad col or row; got %v", err) + eds := createTestEds(codec, shareSize) + for i, coords := range test.coords { + x := coords[0] + y := coords[1] + eds.setCell(x, y, test.values[i]) } + err := eds.Repair(eds.getRowRoots(), eds.getColRoots()) + assert.Error(t, err) + + // due to parallelisation, the ErrByzantineData axis may be either row or col + var byzData *ErrByzantineData + assert.ErrorAs(t, err, &byzData, "did not return a ErrByzantineData for a bad col or row") }) } }) @@ -273,11 +284,11 @@ func BenchmarkRepair(b *testing.B) { } } -func createTestEds(codec Codec, bufferSize int) *ExtendedDataSquare { - ones := bytes.Repeat([]byte{1}, bufferSize) - twos := bytes.Repeat([]byte{2}, bufferSize) - threes := bytes.Repeat([]byte{3}, bufferSize) - fours := bytes.Repeat([]byte{4}, bufferSize) +func createTestEds(codec Codec, shareSize int) *ExtendedDataSquare { + ones := bytes.Repeat([]byte{1}, shareSize) + twos := bytes.Repeat([]byte{2}, shareSize) + threes := bytes.Repeat([]byte{3}, shareSize) + fours := bytes.Repeat([]byte{4}, shareSize) eds, err := ComputeExtendedDataSquare([][]byte{ ones, twos,