Skip to content

Commit

Permalink
revert!: rsmt2d#277 and rsmt2d#287 (#295)
Browse files Browse the repository at this point in the history
While attempting to bump celestia-app to the v0.12.0-rc2, I noticed that
the `RegisterTree` design leaks an implementation detail to
celestia-app: the registering and managing of `treeName`s. Celestia-app
has two categories of of trees:
1. erasured namespaced merkle tree in
[nmt_wrapper.go](https://github.com/celestiaorg/celestia-app/blob/main/pkg/wrapper/nmt_wrapper.go)
2. EDS subtree root cacher
[nmt_caching.go](https://github.com/celestiaorg/celestia-app/blob/main/pkg/inclusion/nmt_caching.go)

Each of those categories has trees based on square size and NMT options.
Celestia-app needs to be careful to register all the appropriate trees
once (and only once) before they are used (via `Compute` or `Import`).
I'd like to explore a less breaking option to get celestia-node the
original desired feature which was
celestiaorg/rsmt2d#275. In the meantime, I
think we should revert the two big breaking changes so that main can
remain release-able.

Revert celestiaorg/rsmt2d#277
Revert celestiaorg/rsmt2d#287
Closes celestiaorg/rsmt2d#295 because no longer
relevant if we merge this.
  • Loading branch information
rootulp authored Feb 7, 2024
1 parent 322b817 commit 51f8909
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 365 deletions.
48 changes: 0 additions & 48 deletions default_tree.go

This file was deleted.

20 changes: 8 additions & 12 deletions extendeddatacrossword_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestRepairExtendedDataSquare(t *testing.T) {
flattened[12], flattened[13] = nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, DefaultTreeName)
eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand All @@ -67,7 +67,7 @@ func TestRepairExtendedDataSquare(t *testing.T) {
flattened[12], flattened[13], flattened[14] = nil, nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, DefaultTreeName)
eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func BenchmarkRepair(b *testing.B) {

// Generate a new range original data square then extend it
square := genRandDS(originalDataWidth, shareSize)
eds, err := ComputeExtendedDataSquare(square, codec, DefaultTreeName)
eds, err := ComputeExtendedDataSquare(square, codec, NewDefaultTree)
if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -275,7 +275,7 @@ func BenchmarkRepair(b *testing.B) {
}

// Re-import the data square.
eds, _ = ImportExtendedDataSquare(flattened, codec, DefaultTreeName)
eds, _ = ImportExtendedDataSquare(flattened, codec, NewDefaultTree)

b.StartTimer()

Expand All @@ -301,7 +301,7 @@ func createTestEds(codec Codec, shareSize int) *ExtendedDataSquare {
eds, err := ComputeExtendedDataSquare([][]byte{
ones, twos,
threes, fours,
}, codec, DefaultTreeName)
}, codec, NewDefaultTree)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -390,14 +390,8 @@ func TestCorruptedEdsReturnsErrByzantineData_UnorderedShares(t *testing.T) {

codec := NewLeoRSCodec()

edsWidth := 4 // number of shares per row/column in the extended data square
odsWidth := edsWidth / 2 // number of shares per row/column in the original data square
err := RegisterTree("testing-tree", newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize)))
assert.NoError(t, err)

// create a DA header
eds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, 1, 2, 3, 4)

assert.NotNil(t, eds)
dAHeaderRoots, err := eds.getRowRoots()
assert.NoError(t, err)
Expand Down Expand Up @@ -442,8 +436,10 @@ func createTestEdsWithNMT(t *testing.T, codec Codec, shareSize, namespaceSize in
for i, shareValue := range sharesValue {
shares[i] = bytes.Repeat([]byte{byte(shareValue)}, shareSize)
}
edsWidth := 4 // number of shares per row/column in the extended data square
odsWidth := edsWidth / 2 // number of shares per row/column in the original data square

eds, err := ComputeExtendedDataSquare(shares, codec, "testing-tree")
eds, err := ComputeExtendedDataSquare(shares, codec, newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize)))
require.NoError(t, err)

return eds
Expand Down
36 changes: 7 additions & 29 deletions extendeddatasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,29 @@ import (
type ExtendedDataSquare struct {
*dataSquare
codec Codec
treeName string
originalDataWidth uint
}

func (eds *ExtendedDataSquare) MarshalJSON() ([]byte, error) {
return json.Marshal(&struct {
DataSquare [][]byte `json:"data_square"`
Codec string `json:"codec"`
Tree string `json:"tree"`
}{
DataSquare: eds.dataSquare.Flattened(),
Codec: eds.codec.Name(),
Tree: eds.treeName,
})
}

func (eds *ExtendedDataSquare) UnmarshalJSON(b []byte) error {
var aux struct {
DataSquare [][]byte `json:"data_square"`
Codec string `json:"codec"`
Tree string `json:"tree"`
}

err := json.Unmarshal(b, &aux)
if err != nil {
if err := json.Unmarshal(b, &aux); err != nil {
return err
}

if aux.Tree == "" {
aux.Tree = DefaultTreeName
}

importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], aux.Tree)
importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], NewDefaultTree)
if err != nil {
return err
}
Expand All @@ -60,7 +50,7 @@ func (eds *ExtendedDataSquare) UnmarshalJSON(b []byte) error {
func ComputeExtendedDataSquare(
data [][]byte,
codec Codec,
treeName string,
treeCreatorFn TreeConstructorFn,
) (*ExtendedDataSquare, error) {
if len(data) > codec.MaxChunks() {
return nil, errors.New("number of chunks exceeds the maximum")
Expand All @@ -71,18 +61,12 @@ func ComputeExtendedDataSquare(
if err != nil {
return nil, err
}

treeCreatorFn, err := TreeFn(treeName)
if err != nil {
return nil, err
}

ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize))
if err != nil {
return nil, err
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: treeName}
eds := ExtendedDataSquare{dataSquare: ds, codec: codec}
err = eds.erasureExtendSquare(codec)
if err != nil {
return nil, err
Expand All @@ -95,7 +79,7 @@ func ComputeExtendedDataSquare(
func ImportExtendedDataSquare(
data [][]byte,
codec Codec,
treeName string,
treeCreatorFn TreeConstructorFn,
) (*ExtendedDataSquare, error) {
if len(data) > 4*codec.MaxChunks() {
return nil, errors.New("number of chunks exceeds the maximum")
Expand All @@ -106,18 +90,12 @@ func ImportExtendedDataSquare(
if err != nil {
return nil, err
}

treeCreatorFn, err := TreeFn(treeName)
if err != nil {
return nil, err
}

ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize))
if err != nil {
return nil, err
}

eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: treeName}
eds := ExtendedDataSquare{dataSquare: ds, codec: codec}
err = validateEdsWidth(eds.width)
if err != nil {
return nil, err
Expand Down Expand Up @@ -248,7 +226,7 @@ func (eds *ExtendedDataSquare) erasureExtendCol(codec Codec, i uint) error {
}

func (eds *ExtendedDataSquare) deepCopy(codec Codec) (ExtendedDataSquare, error) {
imported, err := ImportExtendedDataSquare(eds.Flattened(), codec, eds.treeName)
imported, err := ImportExtendedDataSquare(eds.Flattened(), codec, eds.createTreeFn)
return *imported, err
}

Expand Down
Loading

0 comments on commit 51f8909

Please sign in to comment.