diff --git a/default_tree.go b/default_tree.go new file mode 100644 index 0000000..deb9224 --- /dev/null +++ b/default_tree.go @@ -0,0 +1,48 @@ +package rsmt2d + +import ( + "crypto/sha256" + "fmt" + + "github.com/celestiaorg/merkletree" +) + +var DefaultTreeName = "default-tree" + +func init() { + err := RegisterTree(DefaultTreeName, NewDefaultTree) + if err != nil { + panic(fmt.Sprintf("%s already registered", DefaultTreeName)) + } +} + +var _ Tree = &DefaultTree{} + +type DefaultTree struct { + *merkletree.Tree + leaves [][]byte + root []byte +} + +func NewDefaultTree(_ Axis, _ uint) Tree { + return &DefaultTree{ + Tree: merkletree.New(sha256.New()), + leaves: make([][]byte, 0, 128), + } +} + +func (d *DefaultTree) Push(data []byte) error { + // ignore the idx, as this implementation doesn't need that info + d.leaves = append(d.leaves, data) + return nil +} + +func (d *DefaultTree) Root() ([]byte, error) { + if d.root == nil { + for _, l := range d.leaves { + d.Tree.Push(l) + } + d.root = d.Tree.Root() + } + return d.root, nil +} diff --git a/extendeddatacrossword_test.go b/extendeddatacrossword_test.go index 6266b79..5b15465 100644 --- a/extendeddatacrossword_test.go +++ b/extendeddatacrossword_test.go @@ -390,8 +390,14 @@ func TestCorruptedEdsReturnsErrByzantineData_UnorderedShares(t *testing.T) { codec := NewLeoRSCodec() + edsWidth := 4 // number of shares per row/column in the extended data square + odsWidth := edsWidth / 2 // number of shares per row/column in the original data square + err := RegisterTree("testing-tree", newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize))) + assert.NoError(t, err) + // create a DA header eds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, 1, 2, 3, 4) + assert.NotNil(t, eds) dAHeaderRoots, err := eds.getRowRoots() assert.NoError(t, err) @@ -436,10 +442,11 @@ func createTestEdsWithNMT(t *testing.T, codec Codec, shareSize, namespaceSize in for i, shareValue := range sharesValue { shares[i] = bytes.Repeat([]byte{byte(shareValue)}, shareSize) } - edsWidth := 4 // number of shares per row/column in the extended data square - odsWidth := edsWidth / 2 // number of shares per row/column in the original data square - eds, err := ComputeExtendedDataSquare(shares, codec, newConstructor(uint64(odsWidth), nmt.NamespaceIDSize(namespaceSize))) + treeConstructorFn, err := TreeFn("testing-tree") + require.NoError(t, err) + + eds, err := ComputeExtendedDataSquare(shares, codec, treeConstructorFn) require.NoError(t, err) return eds diff --git a/extendeddatasquare.go b/extendeddatasquare.go index d076dd1..ec58197 100644 --- a/extendeddatasquare.go +++ b/extendeddatasquare.go @@ -15,6 +15,7 @@ import ( type ExtendedDataSquare struct { *dataSquare codec Codec + treeName string originalDataWidth uint } @@ -22,9 +23,11 @@ func (eds *ExtendedDataSquare) MarshalJSON() ([]byte, error) { return json.Marshal(&struct { DataSquare [][]byte `json:"data_square"` Codec string `json:"codec"` + Tree string `json:"tree"` }{ DataSquare: eds.dataSquare.Flattened(), Codec: eds.codec.Name(), + Tree: eds.treeName, }) } @@ -32,12 +35,25 @@ func (eds *ExtendedDataSquare) UnmarshalJSON(b []byte) error { var aux struct { DataSquare [][]byte `json:"data_square"` Codec string `json:"codec"` + Tree string `json:"tree"` } - if err := json.Unmarshal(b, &aux); err != nil { + err := json.Unmarshal(b, &aux) + if err != nil { + return err + } + + var treeConstructor TreeConstructorFn + if aux.Tree == "" { + aux.Tree = DefaultTreeName + } + + treeConstructor, err = TreeFn(aux.Tree) + if err != nil { return err } - importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], NewDefaultTree) + + importedEds, err := ImportExtendedDataSquare(aux.DataSquare, codecs[aux.Codec], treeConstructor) if err != nil { return err } @@ -61,12 +77,18 @@ func ComputeExtendedDataSquare( if err != nil { return nil, err } + ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize)) if err != nil { return nil, err } - eds := ExtendedDataSquare{dataSquare: ds, codec: codec} + treeName := getTreeNameFromConstructorFn(treeCreatorFn) + if treeName == "" { + return nil, errors.New("tree name not found") + } + + eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: treeName} err = eds.erasureExtendSquare(codec) if err != nil { return nil, err @@ -90,12 +112,18 @@ func ImportExtendedDataSquare( if err != nil { return nil, err } + ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize)) if err != nil { return nil, err } - eds := ExtendedDataSquare{dataSquare: ds, codec: codec} + treeName := getTreeNameFromConstructorFn(treeCreatorFn) + if treeName == "" { + return nil, errors.New("tree name not found") + } + + eds := ExtendedDataSquare{dataSquare: ds, codec: codec, treeName: treeName} err = validateEdsWidth(eds.width) if err != nil { return nil, err diff --git a/extendeddatasquare_test.go b/extendeddatasquare_test.go index 057e78f..041aac4 100644 --- a/extendeddatasquare_test.go +++ b/extendeddatasquare_test.go @@ -111,6 +111,69 @@ func TestMarshalJSON(t *testing.T) { } } +// TestUnmarshalJSON test the UnmarshalJSON function. +func TestUnmarshalJSON(t *testing.T) { + treeName := "testing_unmarshalJSON_tree" + treeConstructorFn := sudoConstructorFn + err := RegisterTree(treeName, treeConstructorFn) + require.NoError(t, err) + + codec := NewLeoRSCodec() + result, err := ComputeExtendedDataSquare([][]byte{ + ones, twos, + threes, fours, + }, codec, treeConstructorFn) + if err != nil { + panic(err) + } + + tests := []struct { + name string + malleate func() + expectedTreeName string + cleanUp func() + }{ + { + "Tree field exists", + func() {}, + treeName, + func() { + cleanUp(treeName) + }, + }, + { + "Tree field missing", + func() { + // clear the tree name value in the eds before marshal + result.treeName = "" + }, + DefaultTreeName, + func() {}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.malleate() + edsBytes, err := json.Marshal(result) + if err != nil { + t.Errorf("failed to marshal EDS: %v", err) + } + + var eds ExtendedDataSquare + err = json.Unmarshal(edsBytes, &eds) + if err != nil { + t.Errorf("failed to unmarshal EDS: %v", err) + } + if !reflect.DeepEqual(result.squareRow, eds.squareRow) { + t.Errorf("eds not equal after json marshal/unmarshal") + } + require.Equal(t, test.expectedTreeName, eds.treeName) + + test.cleanUp() + }) + } +} + func TestNewExtendedDataSquare(t *testing.T) { t.Run("returns an error if edsWidth is not even", func(t *testing.T) { edsWidth := uint(1) diff --git a/tree.go b/tree.go index f8dcc66..4523469 100644 --- a/tree.go +++ b/tree.go @@ -1,9 +1,9 @@ package rsmt2d import ( - "crypto/sha256" - - "github.com/celestiaorg/merkletree" + "fmt" + "reflect" + "sync" ) // TreeConstructorFn creates a fresh Tree instance to be used as the Merkle tree @@ -22,33 +22,65 @@ type Tree interface { Root() ([]byte, error) } -var _ Tree = &DefaultTree{} +// treeFns is a global map used for keeping track of registered tree constructors for JSON serialization +// The keys of this map should be kebab cased. E.g. "default-tree" +var treeFns = sync.Map{} + +// RegisterTree must be called in the init function +func RegisterTree(treeName string, treeConstructor TreeConstructorFn) error { + if _, ok := treeFns.Load(treeName); ok { + return fmt.Errorf("%s already registered", treeName) + } + + treeFns.Store(treeName, treeConstructor) -type DefaultTree struct { - *merkletree.Tree - leaves [][]byte - root []byte + return nil } -func NewDefaultTree(_ Axis, _ uint) Tree { - return &DefaultTree{ - Tree: merkletree.New(sha256.New()), - leaves: make([][]byte, 0, 128), +// TreeFn get tree constructor function by tree name from the global map registry +func TreeFn(treeName string) (TreeConstructorFn, error) { + var treeFn TreeConstructorFn + v, ok := treeFns.Load(treeName) + if !ok { + return nil, fmt.Errorf("%s not registered yet", treeName) + } + treeFn, ok = v.(TreeConstructorFn) + if !ok { + return nil, fmt.Errorf("key %s has invalid interface", treeName) } + + return treeFn, nil } -func (d *DefaultTree) Push(data []byte) error { - // ignore the idx, as this implementation doesn't need that info - d.leaves = append(d.leaves, data) - return nil +// removeTreeFn removes a treeConstructorFn by treeName. +// Only use for test cleanup. Proceed with caution. +func removeTreeFn(treeName string) { + treeFns.Delete(treeName) } -func (d *DefaultTree) Root() ([]byte, error) { - if d.root == nil { - for _, l := range d.leaves { - d.Tree.Push(l) +// Get the tree name by the tree constructor function from the global map registry +// TODO: this code is temporary until all breaking changes is handle here: https://github.com/celestiaorg/rsmt2d/pull/278 +func getTreeNameFromConstructorFn(treeConstructor TreeConstructorFn) string { + key := "" + treeFns.Range(func(k, v interface{}) bool { + keyString, ok := k.(string) + if !ok { + // continue checking other key, value + return true } - d.root = d.Tree.Root() - } - return d.root, nil + treeFn, ok := v.(TreeConstructorFn) + if !ok { + // continue checking other key, value + return true + } + + if reflect.DeepEqual(reflect.ValueOf(treeFn), reflect.ValueOf(treeConstructor)) { + key = keyString + return false + } + + return true + }) + + return key } diff --git a/tree_test.go b/tree_test.go new file mode 100644 index 0000000..b5fadc3 --- /dev/null +++ b/tree_test.go @@ -0,0 +1,202 @@ +package rsmt2d + +import ( + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestRegisterTree tests the RegisterTree function for adding +// a tree constructor function for a given tree name into treeFns +// global map. +func TestRegisterTree(t *testing.T) { + treeName := "testing_register_tree" + treeConstructorFn := sudoConstructorFn + + tests := []struct { + name string + expectErr error + }{ + // The tree has not been registered yet in the treeFns global map. + {"register successfully", nil}, + // The tree has already been registered in the treeFns global map. + {"register unsuccessfully", fmt.Errorf("%s already registered", treeName)}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := RegisterTree(treeName, treeConstructorFn) + if test.expectErr != nil { + require.Equal(t, test.expectErr, err) + } + + treeFn, err := TreeFn(treeName) + require.NoError(t, err) + assert.True(t, reflect.DeepEqual(reflect.ValueOf(treeFn), reflect.ValueOf(treeConstructorFn))) + }) + } + + cleanUp(treeName) +} + +// TestTreeFn test the TreeFn function which fetches the +// tree constructor function from the treeFns golbal map. +func TestTreeFn(t *testing.T) { + treeName := "testing_treeFn_tree" + treeConstructorFn := sudoConstructorFn + invalidCaseTreeName := "testing_invalid_register_tree" + invalidTreeConstructorFn := "invalid constructor fn" + + tests := []struct { + name string + treeName string + malleate func() + expectErr error + }{ + // The tree constructor function is successfully fetched + // from the global map. + { + "get successfully", + treeName, + func() { + err := RegisterTree(treeName, treeConstructorFn) + require.NoError(t, err) + }, + nil, + }, + // Unable to fetch the tree constructor function for an + // unregisted tree name. + { + "get unregisted tree name", + "unregistered_tree", + func() {}, + fmt.Errorf("%s not registered yet", "unregistered_tree"), + }, + // Value returned from the global map is an invalid value that + // cannot be type asserted into TreeConstructorFn type. + { + "get invalid interface value", + invalidCaseTreeName, + func() { + // Seems like this case has low probability of happening + // since all register has been done through RegisterTree func + // which have strict type check as argument. + treeFns.Store(invalidCaseTreeName, invalidTreeConstructorFn) + }, + fmt.Errorf("key %s has invalid interface", invalidCaseTreeName), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.malleate() + + treeFn, err := TreeFn(test.treeName) + if test.expectErr != nil { + require.Equal(t, test.expectErr, err) + } else { + require.NoError(t, err) + require.True(t, reflect.DeepEqual(reflect.ValueOf(treeFn), reflect.ValueOf(treeConstructorFn))) + } + }) + + cleanUp(test.treeName) + } +} + +// TestGetTreeNameFromConstructorFn tests the GetTreeNameFromConstructorFn +// function which fetches tree name by it corresponding tree constructor function. +// +// TODO: When we handle all the breaking changes track in this PR: https://github.com/celestiaorg/rsmt2d/pull/278, should remove this test +func TestGetTreeNameFromConstructorFn(t *testing.T) { + treeName := "testing_get_tree_name_tree" + treeConstructorFn := sudoConstructorFn + invalidTreeName := struct{}{} + invalidCaseTreeName := "invalid_case_tree" + invalidTreeConstructorFn := "invalid constructor fn" + + tests := []struct { + name string + treeName string + treeFn TreeConstructorFn + malleate func() + expectGetKey bool + }{ + // The tree name is successfully fetched. + { + "get successfully", + treeName, + treeConstructorFn, + func() { + err := RegisterTree(treeName, treeConstructorFn) + require.NoError(t, err) + }, + true, + }, + // Unable to fetch an unregistered tree name. + { + "get unregisted tree name", + "unregisted_tree_name", + nil, + func() {}, + false, + }, + // Value (tree constructor function) from the global map iteration is an invalid + // value that cannot be type asserted into TreeConstructorFn type. + { + "get invalid interface value", + "", + nil, + func() { + // Seems like this case has low probability of happening + // since all register has been done through RegisterTree func + // which have strict type check as argument. + treeFns.Store(invalidCaseTreeName, invalidTreeConstructorFn) + }, + false, + }, + // Key (tree name) from the global map iteration is an invalid value that cannot + // be type asserted into string type. + { + "get invalid interface key", + "", + nil, + func() { + // Seems like this case has low probability of happening + // since all register has been done through RegisterTree func + // which have strict type check as argument. + treeFns.Store(invalidTreeName, treeConstructorFn) + }, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.malleate() + + key := getTreeNameFromConstructorFn(test.treeFn) + if !test.expectGetKey { + require.Equal(t, key, "") + } else { + require.Equal(t, test.treeName, key) + } + }) + + cleanUp(test.treeName) + } +} + +// Avoid duplicate with default_tree treeConstructorFn +// registered during init. +func sudoConstructorFn(_ Axis, _ uint) Tree { + return &DefaultTree{} +} + +// Clear tested tree constructor function in the global map. +func cleanUp(treeName string) { + removeTreeFn(treeName) +}