diff --git a/datasquare.go b/datasquare.go index 45df481..5461d53 100644 --- a/datasquare.go +++ b/datasquare.go @@ -26,6 +26,9 @@ type dataSquare struct { createTreeFn TreeConstructorFn } +// newDataSquare populates the data square from the supplied data and treeCreator. +// No root calculation is performed. +// data may have nil values. func newDataSquare(data [][]byte, treeCreator TreeConstructorFn, chunkSize uint) (*dataSquare, error) { width := int(math.Ceil(math.Sqrt(float64(len(data))))) if width*width != len(data) { diff --git a/extendeddatacrossword.go b/extendeddatacrossword.go index 2df4428..ade26e8 100644 --- a/extendeddatacrossword.go +++ b/extendeddatacrossword.go @@ -53,7 +53,8 @@ type ErrByzantineData struct { } func (e *ErrByzantineData) Error() string { - return fmt.Sprintf("byzantine %s: %d", e.Axis, e.Index) + return fmt.Sprintf( + "byzantine %s: %d", e.Axis, e.Index) } // Repair attempts to repair an incomplete extended data square (EDS). The @@ -142,7 +143,7 @@ func (eds *ExtendedDataSquare) solveCrosswordRow( shares[c] = vectorData[c] } - // Attempt rebuild + // Attempt rebuild the row rebuiltShares, isDecoded, err := eds.rebuildShares(shares) if err != nil { return false, false, err @@ -167,7 +168,7 @@ func (eds *ExtendedDataSquare) solveCrosswordRow( if col[r] != nil { continue // not newly completed } - if noMissingData(col, r) { // not completed + if noMissingData(col, r) { // completed err := eds.verifyAgainstColRoots(colRoots, uint(c), col, r, rebuiltShares[c]) if err != nil { var byzErr *ErrByzantineData @@ -240,7 +241,7 @@ func (eds *ExtendedDataSquare) solveCrosswordCol( if row[c] != nil { continue // not newly completed } - if noMissingData(row, c) { // not completed + if noMissingData(row, c) { // completed err := eds.verifyAgainstRowRoots(rowRoots, uint(r), row, c, rebuiltShares[r]) if err != nil { var byzErr *ErrByzantineData @@ -299,35 +300,46 @@ func (eds *ExtendedDataSquare) verifyAgainstRowRoots( root, err = eds.computeSharesRootWithRebuiltShare(oldShares, Row, r, rebuiltIndex, rebuiltShare) } if err != nil { - return err + // any error during the computation of the root is considered byzantine + // the shares are set to nil, as the caller will populate them + return &ErrByzantineData{Row, r, nil} } if !bytes.Equal(root, rowRoots[r]) { + // the shares are set to nil, as the caller will populate them return &ErrByzantineData{Row, r, nil} } return nil } +// verifyAgainstColRoots checks that the shares of column index `c` match their expected column root available in `colRoots`. +// `colRoots` is a slice of the expected roots of the columns of the `eds`. +// `shares` is a slice of the shares of the column index `c` of the `eds`. +// `rebuiltIndex` is the index of the share that was rebuilt, if any. +// `rebuiltShare` is the rebuilt share, if any. +// Returns a ErrByzantineData error if the computed root does not match the expected root or if the root computation fails. func (eds *ExtendedDataSquare) verifyAgainstColRoots( colRoots [][]byte, c uint, - oldShares [][]byte, + shares [][]byte, rebuiltIndex int, rebuiltShare []byte, ) error { var root []byte var err error if rebuiltIndex < 0 || rebuiltShare == nil { - root, err = eds.computeSharesRoot(oldShares, Col, c) + root, err = eds.computeSharesRoot(shares, Col, c) } else { - root, err = eds.computeSharesRootWithRebuiltShare(oldShares, Col, c, rebuiltIndex, rebuiltShare) + root, err = eds.computeSharesRootWithRebuiltShare(shares, Col, c, rebuiltIndex, rebuiltShare) } if err != nil { - return err + // the shares are set to nil, as the caller will populate them + return &ErrByzantineData{Col, c, nil} } if !bytes.Equal(root, colRoots[c]) { + // the shares are set to nil, as the caller will populate them return &ErrByzantineData{Col, c, nil} } @@ -353,10 +365,13 @@ func (eds *ExtendedDataSquare) preRepairSanityCheck( // ensure that the roots are equal rowRoot, err := eds.getRowRoot(i) if err != nil { - return err + // any error regarding the root calculation signifies an issue in the shares e.g., out of order shares + // therefore, it should be treated as byzantine data + return &ErrByzantineData{Row, i, eds.row(i)} } if !bytes.Equal(rowRoots[i], rowRoot) { - return fmt.Errorf("bad root input: row %d expected %v got %v", i, rowRoots[i], rowRoot) + // if the roots are not equal, then the data is byzantine + return &ErrByzantineData{Row, i, eds.row(i)} } return nil }) @@ -379,14 +394,18 @@ func (eds *ExtendedDataSquare) preRepairSanityCheck( // ensure that the roots are equal colRoot, err := eds.getColRoot(i) if err != nil { - return err + // any error regarding the root calculation signifies an issue in the shares e.g., out of order shares + // therefore, it should be treated as byzantine data + return &ErrByzantineData{Col, i, eds.col(i)} } if !bytes.Equal(colRoots[i], colRoot) { - return fmt.Errorf("bad root input: col %d expected %v got %v", i, colRoots[i], colRoot) + // if the roots are not equal, then the data is byzantine + return &ErrByzantineData{Col, i, eds.col(i)} } return nil }) errs.Go(func() error { + // check if we take the first half of the col and encode it, we get the second half parityShares, err := eds.codec.Encode(eds.colSlice(0, i, eds.originalDataWidth)) if err != nil { return err @@ -414,6 +433,7 @@ func noMissingData(input [][]byte, rebuiltIndex int) bool { return true } +// computeSharesRoot calculates the root of the shares for the specified axis (`i`th column or row). func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, axis Axis, i uint) ([]byte, error) { tree := eds.createTreeFn(axis, i) for _, d := range shares { @@ -425,6 +445,7 @@ func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, axis Axis, i u return tree.Root() } +// computeSharesRootWithRebuiltShare computes the root of the shares with the rebuilt share `rebuiltShare` at the specified index `rebuiltIndex`. func (eds *ExtendedDataSquare) computeSharesRootWithRebuiltShare(shares [][]byte, axis Axis, i uint, rebuiltIndex int, rebuiltShare []byte) ([]byte, error) { tree := eds.createTreeFn(axis, i) for _, d := range shares[:rebuiltIndex] { diff --git a/extendeddatacrossword_test.go b/extendeddatacrossword_test.go index fcd8a19..e26c5e9 100644 --- a/extendeddatacrossword_test.go +++ b/extendeddatacrossword_test.go @@ -7,6 +7,7 @@ import ( "math/rand" "testing" + "github.com/celestiaorg/nmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -335,3 +336,141 @@ func createTestEds(codec Codec, shareSize int) *ExtendedDataSquare { return eds } + +func TestCorruptedEdsReturnsErrByzantineData_UnorderedShares(t *testing.T) { + shareSize := 512 + namespaceSize := 1 + one := bytes.Repeat([]byte{1}, shareSize) + two := bytes.Repeat([]byte{2}, shareSize) + three := bytes.Repeat([]byte{3}, shareSize) + sharesValue := []int{1, 2, 3, 4} + tests := []struct { + name string + coords [][]uint + values [][]byte + wantErr bool + corruptedAxis Axis + corruptedIndex uint + }{ + { + name: "no corruption", + wantErr: false, + }, + { + // disturbs the order of shares in the first row, erases the rest of the eds + name: "rows with unordered shares", + wantErr: true, // repair should error out during root construction + corruptedAxis: Row, + coords: [][]uint{ + {0, 0}, + {0, 1}, + {1, 0}, + {1, 1}, + {1, 2}, + {1, 3}, + {2, 0}, + {2, 1}, + {2, 2}, + {2, 3}, + {3, 0}, + {3, 1}, + {3, 2}, + {3, 3}, + }, + values: [][]byte{ + two, one, + nil, nil, nil, nil, + nil, nil, nil, nil, + nil, nil, nil, nil, + }, + corruptedIndex: 0, + }, + { + // disturbs the order of shares in the first column, erases the rest of the eds + name: "columns with unordered shares", + wantErr: true, // repair should error out during root construction + corruptedAxis: Col, + coords: [][]uint{ + {0, 0}, + {0, 1}, + {0, 2}, + {0, 3}, + {1, 0}, + {1, 1}, + {1, 2}, + {1, 3}, + {2, 1}, + {2, 2}, + {2, 3}, + {3, 1}, + {3, 2}, + {3, 3}, + }, + values: [][]byte{ + three, nil, nil, nil, + one, nil, nil, nil, + nil, nil, nil, + nil, nil, nil, + }, + corruptedIndex: 0, + }, + } + + for codecName, codec := range codecs { + t.Run(codecName, func(t *testing.T) { + // create a DA header + eds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, 1, 2, 3, 4) + assert.NotNil(t, eds) + dAHeaderRoots, err := eds.getRowRoots() + assert.NoError(t, err) + + dAHeaderCols, err := eds.getColRoots() + assert.NoError(t, err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // create an eds with the given shares + corruptEds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, sharesValue...) + assert.NotNil(t, corruptEds) + // corrupt it by setting the values at the given coordinates + for i, coords := range test.coords { + x := coords[0] + y := coords[1] + corruptEds.setCell(x, y, test.values[i]) + } + + err = corruptEds.Repair(dAHeaderRoots, dAHeaderCols) + assert.Equal(t, err != nil, test.wantErr) + if test.wantErr { + var byzErr *ErrByzantineData + assert.ErrorAs(t, err, &byzErr) + errors.As(err, &byzErr) + assert.Equal(t, byzErr.Axis, test.corruptedAxis) + assert.Equal(t, byzErr.Index, test.corruptedIndex) + } + }) + } + }) + } +} + +// createTestEdsWithNMT creates an extended data square with the given shares and namespace size. +// Shares are placed in row-major order. +// The first namespaceSize bytes of each share are treated as its namespace. +// Roots of the extended data square are computed using namespace merkle trees. +func createTestEdsWithNMT(t *testing.T, codec Codec, shareSize, namespaceSize int, sharesValue ...int) *ExtendedDataSquare { + // the first namespaceSize bytes of each share are the namespace + assert.True(t, shareSize > namespaceSize) + + // create shares of shareSize bytes + shares := make([][]byte, len(sharesValue)) + for i, shareValue := range sharesValue { + shares[i] = bytes.Repeat([]byte{byte(shareValue)}, shareSize) + } + edsWidth := 4 // number of shares per row/column in the extended data square + odsWidth := edsWidth / 2 // number of shares per row/column in the original data square + + eds, err := ComputeExtendedDataSquare(shares, codec, newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize))) + require.NoError(t, err) + + return eds +} diff --git a/extendeddatasquare_test.go b/extendeddatasquare_test.go index 2b47ab3..a81c247 100644 --- a/extendeddatasquare_test.go +++ b/extendeddatasquare_test.go @@ -74,7 +74,7 @@ func TestComputeExtendedDataSquare(t *testing.T) { func TestImportExtendedDataSquare(t *testing.T) { t.Run("is able to import an EDS", func(t *testing.T) { - eds := createExampleEds(t, ShardSize) + eds := createExampleEds(t, shareSize) got, err := ImportExtendedDataSquare(eds.Flattened(), NewLeoRSCodec(), NewDefaultTree) assert.NoError(t, err) assert.Equal(t, eds.Flattened(), got.Flattened()) diff --git a/go.mod b/go.mod index 747ca8c..4d60a05 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,8 @@ go 1.20 require ( github.com/celestiaorg/merkletree v0.0.0-20210714075610-a84dc3ddbbe4 - github.com/stretchr/testify v1.7.0 + github.com/celestiaorg/nmt v0.17.0 + github.com/stretchr/testify v1.8.4 ) require ( @@ -12,6 +13,8 @@ require ( golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0 ) +require github.com/stretchr/objx v0.5.0 // indirect + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/klauspost/cpuid/v2 v2.1.1 // indirect @@ -22,5 +25,5 @@ require ( golang.org/x/crypto v0.1.0 // indirect golang.org/x/sys v0.1.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bb92e06..17a94f1 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/celestiaorg/merkletree v0.0.0-20210714075610-a84dc3ddbbe4 h1:CJdIpo8n5MFP2MwK0gSRcOVlDlFdQJO1p+FqdxYzmvc= github.com/celestiaorg/merkletree v0.0.0-20210714075610-a84dc3ddbbe4/go.mod h1:fzuHnhzj1pUygGz+1ZkB3uQbEUL4htqCGJ4Qs2LwMZA= +github.com/celestiaorg/nmt v0.17.0 h1:/k8YLwJvuHgT/jQ435zXKaDX811+sYEMXL4B/vYdSLU= +github.com/celestiaorg/nmt v0.17.0/go.mod h1:ZndCeAR4l9lxm7W51ouoyTo1cxhtFgK+4DpEIkxRA3A= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -20,8 +22,16 @@ github.com/minio/sha256-simd v1.0.0/go.mod h1:OuYzVNI5vcoYIAmbIvHPl3N3jUzVedXbKy github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= gitlab.com/NebulousLabs/errors v0.0.0-20171229012116-7ead97ef90b8/go.mod h1:ZkMZ0dpQyWwlENaeZVBiQRjhMEZvk6VTXquzl3FOFP8= gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975 h1:L/ENs/Ar1bFzUeKx6m3XjlmBgIUlykX9dzvp5k9NGxc= gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975/go.mod h1:ZkMZ0dpQyWwlENaeZVBiQRjhMEZvk6VTXquzl3FOFP8= @@ -46,3 +56,5 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/nmtwrapper_test.go b/nmtwrapper_test.go new file mode 100644 index 0000000..a651fc2 --- /dev/null +++ b/nmtwrapper_test.go @@ -0,0 +1,142 @@ +package rsmt2d + +// The contents of this file have been adapted from the source file available at https://github.com/celestiaorg/celestia-app/blob/bab6c0d0befe677ab8c2f4b83561c08affc7203e/pkg/wrapper/nmt_wrapper.go, +// solely for the purpose of testing rsmt2d expected behavior when integrated with a NamespaceMerkleTree. +// Please note that this file has undergone several modifications and may not match the original file exactly. + +import ( + "bytes" + "fmt" + + "github.com/celestiaorg/nmt" + "github.com/celestiaorg/nmt/namespace" + "github.com/minio/sha256-simd" +) + +// Fulfills the Tree interface and TreeConstructorFn function +var ( + _ Tree = &erasuredNamespacedMerkleTree{} +) + +// erasuredNamespacedMerkleTree wraps NamespaceMerkleTree to conform to the +// Tree interface while also providing the correct namespaces to the +// underlying NamespaceMerkleTree. It does this by adding the already included +// namespace to the first half of the tree, and then uses the parity namespace +// ID for each share pushed to the second half of the tree. This allows for the +// namespaces to be included in the erasure data, while also keeping the nmt +// library sufficiently general +type erasuredNamespacedMerkleTree struct { + squareSize uint64 // note: this refers to the width of the original square before erasure-coded + options []nmt.Option + tree nmtTree + // axisIndex is the index of the axis (row or column) that this tree is on. This is passed + // by rsmt2d and used to help determine which quadrant each leaf belongs to. + axisIndex uint64 + // shareIndex is the index of the share in a row or column that is being + // pushed to the tree. It is expected to be in the range: 0 <= shareIndex < + // 2*squareSize. shareIndex is used to help determine which quadrant each + // leaf belongs to, along with keeping track of how many leaves have been + // added to the tree so far. + shareIndex uint64 + namespaceSize int +} + +// nmtTree is an interface that wraps the methods of the underlying +// NamespaceMerkleTree that are used by erasuredNamespacedMerkleTree. This +// interface is mainly used for testing. It is not recommended to use this +// interface by implementing a different implementation. +type nmtTree interface { + Root() ([]byte, error) + Push(namespacedData namespace.PrefixedData) error +} + +// newErasuredNamespacedMerkleTree creates a new erasuredNamespacedMerkleTree +// with an underlying NMT of namespace size `NamespaceSize` and with +// `ignoreMaxNamespace=true`. axisIndex is the index of the row or column that +// this tree is committing to. squareSize must be greater than zero. +func newErasuredNamespacedMerkleTree(squareSize uint64, axisIndex uint, options ...nmt.Option) erasuredNamespacedMerkleTree { + if squareSize == 0 { + panic("cannot create a erasuredNamespacedMerkleTree of squareSize == 0") + } + // read the options to extract the namespace size, and use it to construct erasuredNamespacedMerkleTree + opts := &nmt.Options{} + for _, setter := range options { + setter(opts) + } + options = append(options, nmt.IgnoreMaxNamespace(true)) + tree := nmt.New(sha256.New(), options...) + return erasuredNamespacedMerkleTree{squareSize: squareSize, namespaceSize: int(opts.NamespaceIDSize), options: options, tree: tree, axisIndex: uint64(axisIndex), shareIndex: 0} +} + +type constructor struct { + squareSize uint64 + opts []nmt.Option +} + +// newConstructor creates a tree constructor function as required by rsmt2d to +// calculate the data root. It creates that tree using the +// erasuredNamespacedMerkleTree. +func newConstructor(squareSize uint64, opts ...nmt.Option) TreeConstructorFn { + return constructor{ + squareSize: squareSize, + opts: opts, + }.NewTree +} + +// NewTree creates a new Tree using the +// erasuredNamespacedMerkleTree with predefined square size and +// nmt.Options +func (c constructor) NewTree(_ Axis, axisIndex uint) Tree { + newTree := newErasuredNamespacedMerkleTree(c.squareSize, axisIndex, c.opts...) + return &newTree +} + +// Push adds the provided data to the underlying NamespaceMerkleTree, and +// automatically uses the first erasuredNamespacedMerkleTree.namespaceSize number of bytes as the +// namespace unless the data pushed to the second half of the tree. Fulfills the +// rsmt2d.Tree interface. NOTE: panics if an error is encountered while pushing or +// if the tree size is exceeded. +func (w *erasuredNamespacedMerkleTree) Push(data []byte) error { + ParitySharesNamespaceBytes := bytes.Repeat([]byte{0xFF}, w.namespaceSize) + if w.axisIndex+1 > 2*w.squareSize || w.shareIndex+1 > 2*w.squareSize { + return fmt.Errorf("pushed past predetermined square size: boundary at %d index at %d %d", 2*w.squareSize, w.axisIndex, w.shareIndex) + } + if len(data) < w.namespaceSize { + return fmt.Errorf("data is too short to contain namespace ID") + } + nidAndData := make([]byte, w.namespaceSize+len(data)) + copy(nidAndData[w.namespaceSize:], data) + // use the parity namespace if the cell is not in Q0 of the extended data square + if w.isQuadrantZero() { + copy(nidAndData[:w.namespaceSize], data[:w.namespaceSize]) + } else { + copy(nidAndData[:w.namespaceSize], ParitySharesNamespaceBytes) + } + err := w.tree.Push(nidAndData) + if err != nil { + return err + } + w.incrementShareIndex() + return nil +} + +// Root fulfills the rsmt2d.Tree interface by generating and returning the +// underlying NamespaceMerkleTree Root. +func (w *erasuredNamespacedMerkleTree) Root() ([]byte, error) { + root, err := w.tree.Root() + if err != nil { + return nil, err + } + return root, nil +} + +// incrementShareIndex increments the share index by one. +func (w *erasuredNamespacedMerkleTree) incrementShareIndex() { + w.shareIndex++ +} + +// isQuadrantZero returns true if the current share index and axis index are both +// in the original data square. +func (w *erasuredNamespacedMerkleTree) isQuadrantZero() bool { + return w.shareIndex < w.squareSize && w.axisIndex < w.squareSize +}