diff --git a/share/eds/byzantine/bad_encoding.go b/share/eds/byzantine/bad_encoding.go index 0672096b25..ae77026acf 100644 --- a/share/eds/byzantine/bad_encoding.go +++ b/share/eds/byzantine/bad_encoding.go @@ -2,7 +2,6 @@ package byzantine import ( "bytes" - "errors" "fmt" "github.com/celestiaorg/celestia-app/pkg/wrapper" @@ -113,50 +112,58 @@ func (p *BadEncodingProof) UnmarshalBinary(data []byte) error { func (p *BadEncodingProof) Validate(hdr libhead.Header) error { header, ok := hdr.(*header.ExtendedHeader) if !ok { - panic(fmt.Sprintf("invalid header type: expected %T, got %T", header, hdr)) + panic(fmt.Sprintf("invalid header type received during BEFP validation: expected %T, got %T", header, hdr)) } if header.Height() != int64(p.BlockHeight) { - return errors.New("fraud: incorrect block height") + return fmt.Errorf("incorrect block height during BEFP validation: expected %d, got %d", + p.BlockHeight, header.Height(), + ) } - merkleRowRoots := header.DAH.RowRoots - merkleColRoots := header.DAH.ColumnRoots - if len(merkleRowRoots) != len(merkleColRoots) { + + if len(header.DAH.RowRoots) != len(header.DAH.ColumnRoots) { // NOTE: This should never happen as callers of this method should not feed it with a // malformed extended header. panic(fmt.Sprintf( - "fraud: invalid extended header: length of row and column roots do not match. (rowRoots=%d) (colRoots=%d)", - len(merkleRowRoots), - len(merkleColRoots)), + "invalid extended header: length of row and column roots do not match. (rowRoots=%d) (colRoots=%d)", + len(header.DAH.RowRoots), + len(header.DAH.ColumnRoots)), ) } - if int(p.Index) >= len(merkleRowRoots) { - return fmt.Errorf("fraud: invalid proof: index out of bounds (%d >= %d)", int(p.Index), len(merkleRowRoots)) + + // merkleRoots are the roots against which we are going to check the inclusion of the received + // shares. Changing the order of the roots to prove the shares relative to the orthogonal axis, + // because inside the rsmt2d library rsmt2d.Row = 0 and rsmt2d.Col = 1 + merkleRoots := header.DAH.RowRoots + if p.Axis == rsmt2d.Row { + merkleRoots = header.DAH.ColumnRoots } - if len(merkleRowRoots) != len(p.Shares) { - return fmt.Errorf("fraud: invalid proof: incorrect number of shares %d != %d", len(p.Shares), len(merkleRowRoots)) + if int(p.Index) >= len(merkleRoots) { + return fmt.Errorf("invalid %s proof: index out of bounds (%d >= %d)", + BadEncoding, int(p.Index), len(merkleRoots), + ) } - - root := merkleRowRoots[p.Index] - if p.Axis == rsmt2d.Col { - root = merkleColRoots[p.Index] + if len(p.Shares) != len(merkleRoots) { + return fmt.Errorf("invalid %s proof: incorrect number of shares %d != %d", + BadEncoding, len(p.Shares), len(merkleRoots), + ) } // verify that Merkle proofs correspond to particular shares. - shares := make([][]byte, len(merkleRowRoots)) + shares := make([][]byte, len(merkleRoots)) for index, shr := range p.Shares { if shr == nil { continue } // validate inclusion of the share into one of the DAHeader roots - if ok := shr.Validate(ipld.MustCidFromNamespacedSha256(root)); !ok { - return fmt.Errorf("fraud: invalid proof: incorrect share received at index %d", index) + if ok := shr.Validate(ipld.MustCidFromNamespacedSha256(merkleRoots[index])); !ok { + return fmt.Errorf("invalid %s proof: incorrect share received at index %d", BadEncoding, index) } // NMTree commits the additional namespace while rsmt2d does not know about, so we trim it // this is ugliness from NMTWrapper that we have to embrace ¯\_(ツ)_/¯ shares[index] = share.GetData(shr.Share) } - odsWidth := uint64(len(merkleRowRoots) / 2) + odsWidth := uint64(len(merkleRoots) / 2) codec := share.DefaultRSMT2DCodec() // rebuild a row or col. @@ -183,10 +190,15 @@ func (p *BadEncodingProof) Validate(hdr libhead.Header) error { return err } + // root is a merkle root of the row/col where ErrByzantine occurred + root := header.DAH.RowRoots[p.Index] + if p.Axis == rsmt2d.Col { + root = header.DAH.ColumnRoots[p.Index] + } + // comparing rebuilt Merkle Root of bad row/col with respective Merkle Root of row/col from block. if bytes.Equal(expectedRoot, root) { - return errors.New("fraud: invalid proof: recomputed Merkle root matches the DAH's row/column root") + return fmt.Errorf("invalid %s proof: recomputed Merkle root matches the DAH's row/column root", BadEncoding) } - return nil } diff --git a/share/eds/byzantine/byzantine.go b/share/eds/byzantine/byzantine.go index b9c8ef414f..0fcd78273e 100644 --- a/share/eds/byzantine/byzantine.go +++ b/share/eds/byzantine/byzantine.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/ipfs/go-blockservice" + "golang.org/x/sync/errgroup" "github.com/celestiaorg/celestia-app/pkg/da" "github.com/celestiaorg/rsmt2d" @@ -35,17 +36,41 @@ func NewErrByzantine( dah *da.DataAvailabilityHeader, errByz *rsmt2d.ErrByzantineData, ) *ErrByzantine { - root := [][][]byte{ - dah.RowRoots, + // changing the order to collect proofs against an orthogonal axis + roots := [][][]byte{ dah.ColumnRoots, - }[errByz.Axis][errByz.Index] - sharesWithProof, err := GetProofsForShares( - ctx, - bGetter, - ipld.MustCidFromNamespacedSha256(root), - errByz.Shares, - ) - if err != nil { + dah.RowRoots, + }[errByz.Axis] + + sharesWithProof := make([]*ShareWithProof, len(errByz.Shares)) + sharesAmount := 0 + + errGr, ctx := errgroup.WithContext(ctx) + for index, share := range errByz.Shares { + // skip further shares if we already requested half of them, which is enough to recompute the row + // or col + if sharesAmount == len(dah.RowRoots)/2 { + break + } + + if share == nil { + continue + } + sharesAmount++ + + index := index + errGr.Go(func() error { + share, err := getProofsAt( + ctx, bGetter, + ipld.MustCidFromNamespacedSha256(roots[index]), + int(errByz.Index), len(errByz.Shares), + ) + sharesWithProof[index] = share + return err + }) + } + + if err := errGr.Wait(); err != nil { // Fatal as rsmt2d proved that error is byzantine, // but we cannot properly collect the proof, // so verification will fail and thus services won't be stopped @@ -53,7 +78,6 @@ func NewErrByzantine( // TODO(@Wondertan): Find a better way to handle log.Fatalw("getting proof for ErrByzantine", "err", err) } - return &ErrByzantine{ Index: uint32(errByz.Index), Shares: sharesWithProof, diff --git a/share/eds/byzantine/share_proof.go b/share/eds/byzantine/share_proof.go index b8e39ee1d3..d6aa9dad51 100644 --- a/share/eds/byzantine/share_proof.go +++ b/share/eds/byzantine/share_proof.go @@ -78,24 +78,38 @@ func GetProofsForShares( proofs := make([]*ShareWithProof, len(shares)) for index, share := range shares { if share != nil { - proof := make([]cid.Cid, 0) - // TODO(@vgonkivs): Combine GetLeafData and GetProof in one function as the are traversing the same - // tree. Add options that will control what data will be fetched. - s, err := ipld.GetLeaf(ctx, bGetter, root, index, len(shares)) + proof, err := getProofsAt(ctx, bGetter, root, index, len(shares)) if err != nil { return nil, err } - proof, err = ipld.GetProof(ctx, bGetter, root, proof, index, len(shares)) - if err != nil { - return nil, err - } - proofs[index] = NewShareWithProof(index, s.RawData(), proof) + proofs[index] = proof } } - return proofs, nil } +func getProofsAt( + ctx context.Context, + bGetter blockservice.BlockGetter, + root cid.Cid, + index, + total int, +) (*ShareWithProof, error) { + proof := make([]cid.Cid, 0) + // TODO(@vgonkivs): Combine GetLeafData and GetProof in one function as the are traversing the same + // tree. Add options that will control what data will be fetched. + node, err := ipld.GetLeaf(ctx, bGetter, root, index, total) + if err != nil { + return nil, err + } + + proof, err = ipld.GetProof(ctx, bGetter, root, proof, index, total) + if err != nil { + return nil, err + } + return NewShareWithProof(index, node.RawData(), proof), nil +} + func ProtoToShare(protoShares []*pb.Share) []*ShareWithProof { shares := make([]*ShareWithProof, len(protoShares)) for i, share := range protoShares {