diff --git a/asserter/block.go b/asserter/block.go index fc951eae7..2b8895522 100644 --- a/asserter/block.go +++ b/asserter/block.go @@ -370,6 +370,10 @@ func (a *Asserter) Transaction( // any of the related transactions contain invalid types, invalid network identifiers, // invalid transaction identifiers, or a direction not defined by the enum. func (a *Asserter) RelatedTransactions(relatedTransactions []*types.RelatedTransaction) error { + if dup := DuplicateRelatedTransaction(relatedTransactions); dup != nil { + return fmt.Errorf("%w: %v", ErrDuplicateRelatedTransaction, dup) + } + for i, relatedTransaction := range relatedTransactions { if relatedTransaction.NetworkIdentifier != nil { if err := NetworkIdentifier(relatedTransaction.NetworkIdentifier); err != nil { @@ -401,6 +405,24 @@ func (a *Asserter) RelatedTransactions(relatedTransactions []*types.RelatedTrans return nil } +// DuplicateRelatedTransaction returns nil if no duplicates are found in the array and +// returns the first duplicated item found otherwise. +func DuplicateRelatedTransaction( + items []*types.RelatedTransaction, +) *types.RelatedTransaction { + seen := map[string]struct{}{} + for _, item := range items { + key := types.Hash(item) + if _, ok := seen[key]; ok { + return item + } + + seen[key] = struct{}{} + } + + return nil +} + // Direction returns an error if the value passed is not types.Forward or types.Backward func (a *Asserter) Direction(direction types.Direction) error { if direction != types.Forward && diff --git a/asserter/block_test.go b/asserter/block_test.go index 933fbe402..b45cb5f02 100644 --- a/asserter/block_test.go +++ b/asserter/block_test.go @@ -727,6 +727,59 @@ func TestBlock(t *testing.T) { }, }, } + duplicateRelatedTransactions := &types.Transaction{ + TransactionIdentifier: &types.TransactionIdentifier{ + Hash: "blah", + }, + Operations: []*types.Operation{ + { + OperationIdentifier: &types.OperationIdentifier{ + Index: int64(0), + }, + Type: "PAYMENT", + Status: types.String("SUCCESS"), + Account: validAccount, + Amount: validAmount, + }, + { + OperationIdentifier: &types.OperationIdentifier{ + Index: int64(1), + }, + RelatedOperations: []*types.OperationIdentifier{ + { + Index: int64(0), + }, + }, + Type: "PAYMENT", + Status: types.String("SUCCESS"), + Account: validAccount, + Amount: validAmount, + }, + }, + RelatedTransactions: []*types.RelatedTransaction{ + { + NetworkIdentifier: &types.NetworkIdentifier{ + Blockchain: "hello", + Network: "world", + }, + TransactionIdentifier: &types.TransactionIdentifier{ + Hash: "blah", + }, + Direction: types.Forward, + }, + { + NetworkIdentifier: &types.NetworkIdentifier{ + Blockchain: "hello", + Network: "world", + }, + TransactionIdentifier: &types.TransactionIdentifier{ + Hash: "blah", + }, + Direction: types.Forward, + }, + }, + } + var tests = map[string]struct { block *types.Block genesisIndex int64 @@ -908,6 +961,15 @@ func TestBlock(t *testing.T) { }, err: ErrInvalidDirection, }, + "duplicate related transaction": { + block: &types.Block{ + BlockIdentifier: validBlockIdentifier, + ParentBlockIdentifier: validParentBlockIdentifier, + Timestamp: MinUnixEpoch + 1, + Transactions: []*types.Transaction{duplicateRelatedTransactions}, + }, + err: ErrDuplicateRelatedTransaction, + }, } for name, test := range tests { diff --git a/asserter/errors.go b/asserter/errors.go index 7b9271b7b..8f0b588b0 100644 --- a/asserter/errors.go +++ b/asserter/errors.go @@ -91,7 +91,10 @@ var ( ErrBlockIndexPrecedesParentBlockIndex = errors.New( "BlockIdentifier.Index <= ParentBlockIdentifier.Index", ) - ErrInvalidDirection = errors.New("invalid direction (must be 'forward' or 'backward')") + ErrInvalidDirection = errors.New( + "invalid direction (must be 'forward' or 'backward')", + ) + ErrDuplicateRelatedTransaction = errors.New("duplicate related transaction") BlockErrs = []error{ ErrAmountValueMissing, @@ -127,6 +130,7 @@ var ( ErrBlockHashEqualsParentBlockHash, ErrBlockIndexPrecedesParentBlockIndex, ErrInvalidDirection, + ErrDuplicateRelatedTransaction, } ) diff --git a/storage/errors/errors.go b/storage/errors/errors.go index 25b3d5780..2cf2810c1 100644 --- a/storage/errors/errors.go +++ b/storage/errors/errors.go @@ -403,6 +403,8 @@ var ( ErrNothingToPrune = errors.New("nothing to prune") ErrPruningFailed = errors.New("pruning failed") ErrCannotPruneTransaction = errors.New("cannot prune transaction") + ErrCannotStoreBackwardRelation = errors.New("cannot store backward relation") + ErrCannotRemoveBackwardRelation = errors.New("cannot remove backward relation") BlockStorageErrs = []error{ ErrHeadBlockNotFound, @@ -438,6 +440,8 @@ var ( ErrNothingToPrune, ErrPruningFailed, ErrCannotPruneTransaction, + ErrCannotStoreBackwardRelation, + ErrCannotRemoveBackwardRelation, } ) diff --git a/storage/modules/block_storage.go b/storage/modules/block_storage.go index 8875921a4..490ba0486 100644 --- a/storage/modules/block_storage.go +++ b/storage/modules/block_storage.go @@ -54,6 +54,11 @@ const ( // blockSyncIdentifier is the identifier used to acquire // a database lock. blockSyncIdentifier = "blockSyncIdentifier" + + // backwardRelation is a relation from a child to a root transaction + // the root is the destination and the child is the transaction listing the root as a backward + // relation + backwardRelation = "backwardRelation" // prefix/root/child ) type blockTransaction struct { @@ -103,6 +108,19 @@ func getTransactionPrefix( ) } +// getBackwardRelationKey returns a db key for a backwards relation. passing nil in for the +// child returns a prefix key. +func getBackwardRelationKey( + backwardTransaction *types.TransactionIdentifier, + tx *types.TransactionIdentifier, +) []byte { + childHash := "" + if tx != nil { + childHash = tx.Hash + } + return []byte(fmt.Sprintf("%s/%s/%s", backwardRelation, backwardTransaction.Hash, childHash)) +} + // BlockWorker is an interface that allows for work // to be done while a block is added/removed from storage // in the same database transaction as the change. @@ -798,7 +816,7 @@ func (b *BlockStorage) RemoveBlock( gctx, transaction, blockIdentifier, - txn.TransactionIdentifier, + txn, ) }) } @@ -956,6 +974,11 @@ func (b *BlockStorage) storeTransaction( blockIdentifier *types.BlockIdentifier, tx *types.Transaction, ) error { + err := b.storeBackwardRelations(ctx, transaction, tx) + if err != nil { + return err + } + namespace, hashKey := getTransactionKey(blockIdentifier, tx.TransactionIdentifier) bt := &blockTransaction{ Transaction: tx, @@ -970,6 +993,81 @@ func (b *BlockStorage) storeTransaction( return storeUniqueKey(ctx, transaction, hashKey, encodedResult, true) } +func (b *BlockStorage) storeBackwardRelations( + ctx context.Context, + transaction database.Transaction, + tx *types.Transaction, +) error { + fn := func(ctx context.Context, transaction database.Transaction, key []byte) error { + err := transaction.Set(ctx, key, []byte{}, true) + if err != nil { + return fmt.Errorf("%v: %w", storageErrs.ErrCannotStoreBackwardRelation, err) + } + + return nil + } + + return b.modifyBackwardRelations(ctx, transaction, tx, fn) +} + +func (b *BlockStorage) removeBackwardRelations( + ctx context.Context, + transaction database.Transaction, + tx *types.Transaction, +) error { + fn := func(ctx context.Context, transaction database.Transaction, key []byte) error { + err := transaction.Delete(ctx, key) + if err != nil { + return fmt.Errorf("%w: %v", storageErrs.ErrCannotRemoveBackwardRelation, err) + } + + return nil + } + + return b.modifyBackwardRelations(ctx, transaction, tx, fn) +} + +func (b *BlockStorage) modifyBackwardRelations( + ctx context.Context, + transaction database.Transaction, + tx *types.Transaction, + fn func(ctx context.Context, transaction database.Transaction, key []byte) error, +) error { + var backwardRelationKeys [][]byte + for _, relatedTx := range tx.RelatedTransactions { + // skip if on another network + if relatedTx.NetworkIdentifier != nil { + continue + } + if relatedTx.Direction != types.Backward { + continue + } + + // skip if related block not found + block, _, err := b.FindTransaction(ctx, relatedTx.TransactionIdentifier, transaction) + if err != nil { + return fmt.Errorf("%v: %w", storageErrs.ErrCannotStoreBackwardRelation, err) + } + if block == nil { + continue + } + + backwardRelationKeys = append( + backwardRelationKeys, + getBackwardRelationKey(relatedTx.TransactionIdentifier, tx.TransactionIdentifier), + ) + } + + for _, key := range backwardRelationKeys { + err := fn(ctx, transaction, key) + if err != nil { + return err + } + } + + return nil +} + func (b *BlockStorage) pruneTransaction( ctx context.Context, transaction database.Transaction, @@ -993,10 +1091,14 @@ func (b *BlockStorage) removeTransaction( ctx context.Context, transaction database.Transaction, blockIdentifier *types.BlockIdentifier, - transactionIdentifier *types.TransactionIdentifier, + tx *types.Transaction, ) error { - _, hashKey := getTransactionKey(blockIdentifier, transactionIdentifier) + err := b.removeBackwardRelations(ctx, transaction, tx) + if err != nil { + return err + } + _, hashKey := getTransactionKey(blockIdentifier, tx.TransactionIdentifier) return transaction.Delete(ctx, hashKey) } @@ -1084,6 +1186,109 @@ func (b *BlockStorage) FindTransaction( return newestBlock, newestTransaction, nil } +func (b *BlockStorage) FindRelatedTransactions( + ctx context.Context, + transactionIdentifier *types.TransactionIdentifier, + db database.Transaction, +) (*types.BlockIdentifier, *types.Transaction, []*types.Transaction, error) { + rootBlock, tx, err := b.FindTransaction(ctx, transactionIdentifier, db) + if err != nil { + return nil, nil, nil, err + } + + if rootBlock == nil { + return nil, nil, nil, nil + } + + childIds, err := b.getForwardRelatedTransactions(ctx, tx, db) + if err != nil { + return nil, nil, nil, err + } + + // create map of seen transactions to avoid duplicates + seen := make(map[string]struct{}) + children := []*types.Transaction{} + + i := 0 + for { + if i >= len(childIds) { + break + } + childID := childIds[i] + i++ + + // skip duplicates + if _, ok := seen[childID.Hash]; !ok { + seen[childID.Hash] = struct{}{} + } else { + continue + } + + childBlock, childTx, err := b.FindTransaction(ctx, childID, db) + if err != nil { + return nil, nil, nil, err + } + + if childBlock == nil { + return nil, nil, nil, nil + } + + children = append(children, childTx) + if rootBlock.Index < childBlock.Index { + rootBlock = childBlock + } + + newChildren, err := b.getForwardRelatedTransactions(ctx, childTx, db) + if err != nil { + return nil, nil, nil, err + } + childIds = append(childIds, newChildren...) + } + + return rootBlock, tx, children, nil +} + +// TODO: add support for relations across multiple networks +func (b *BlockStorage) getForwardRelatedTransactions( + ctx context.Context, + tx *types.Transaction, + db database.Transaction, +) ([]*types.TransactionIdentifier, error) { + var children []*types.TransactionIdentifier + for _, relatedTx := range tx.RelatedTransactions { + // skip if on another network + if relatedTx.NetworkIdentifier != nil { + continue + } + + if relatedTx.Direction == types.Forward { + children = append(children, relatedTx.TransactionIdentifier) + } + } + + // scan db for all transactions where tx appears as a backward relation + _, err := db.Scan( + ctx, + getBackwardRelationKey(tx.TransactionIdentifier, nil), + getBackwardRelationKey(tx.TransactionIdentifier, nil), + func(k []byte, v []byte) error { + ss := strings.Split(string(k), "/") + txHash := ss[len(ss)-1] + txID := &types.TransactionIdentifier{Hash: txHash} + children = append(children, txID) + return nil + }, + false, + false, + ) + + if err != nil { + return nil, err + } + + return children, nil +} + func (b *BlockStorage) findBlockTransaction( ctx context.Context, blockIdentifier *types.BlockIdentifier, diff --git a/storage/modules/block_storage_test.go b/storage/modules/block_storage_test.go index 08c92ce6f..4fcc2b58f 100644 --- a/storage/modules/block_storage_test.go +++ b/storage/modules/block_storage_test.go @@ -128,6 +128,23 @@ func simpleTransactionFactory( } } +func addRelatedTransaction( + transaction *types.Transaction, + hash string, + direction types.Direction, +) *types.Transaction { + relatedTx := &types.RelatedTransaction{ + NetworkIdentifier: nil, + TransactionIdentifier: &types.TransactionIdentifier{ + Hash: hash, + }, + Direction: direction, + } + + transaction.RelatedTransactions = append(transaction.RelatedTransactions, relatedTx) + return transaction +} + var ( genesisBlock = &types.Block{ BlockIdentifier: &types.BlockIdentifier{ @@ -902,3 +919,139 @@ func TestAtTip(t *testing.T) { assert.True(t, atTip) }) } + +func TestRelatedTransactions(t *testing.T) { + // setup + ctx := context.Background() + + newDir, err := utils.CreateTempDir() + assert.NoError(t, err) + defer utils.RemoveTempDir(newDir) + + database, err := newTestBadgerDatabase(ctx, newDir) + assert.NoError(t, err) + defer database.Close(ctx) + + storage := NewBlockStorage(database, blockWorkerConcurrency) + + t.Run("test forward and backward relations", func(t *testing.T) { + err = storage.SeeBlock(ctx, genesisBlock) + assert.NoError(t, err) + err = storage.AddBlock(ctx, genesisBlock) + assert.NoError(t, err) + + block1 := &types.Block{ + BlockIdentifier: &types.BlockIdentifier{ + Hash: "blah 1", + Index: 1, + }, + ParentBlockIdentifier: &types.BlockIdentifier{ + Hash: "blah 0", + Index: 0, + }, + Timestamp: 1, + Transactions: []*types.Transaction{ + addRelatedTransaction( + simpleTransactionFactory( + "parentTx", + "addr1", + "100", + &types.Currency{Symbol: "hello"}, + ), + "childTx", + types.Forward, + ), + simpleTransactionFactory( + "backwardRelative", + "addr2", + "100", + &types.Currency{Symbol: "hello"}, + ), + }, + } + err = storage.SeeBlock(ctx, block1) + assert.NoError(t, err) + err = storage.AddBlock(ctx, block1) + assert.NoError(t, err) + + block2 := &types.Block{ + BlockIdentifier: &types.BlockIdentifier{ + Hash: "blah 2", + Index: 2, + }, + ParentBlockIdentifier: &types.BlockIdentifier{ + Hash: "blah 1", + Index: 1, + }, + Timestamp: 1, + Transactions: []*types.Transaction{ + simpleTransactionFactory( + "childTx", + "addr3", + "100", + &types.Currency{Symbol: "hello"}, + ), + addRelatedTransaction( + simpleTransactionFactory( + "backwardTx", + "addr4", + "100", + &types.Currency{Symbol: "hello"}, + ), + "backwardRelative", + types.Backward, + ), + addRelatedTransaction( + simpleTransactionFactory( + "badForward", + "addr5", + "100", + &types.Currency{Symbol: "hello"}, + ), + "invalid", + types.Forward, + ), + }, + } + err = storage.SeeBlock(ctx, block2) + assert.NoError(t, err) + err = storage.AddBlock(ctx, block2) + assert.NoError(t, err) + + _, _, related, err := storage.FindRelatedTransactions( + ctx, + block1.Transactions[0].TransactionIdentifier, + storage.db.ReadTransaction(ctx), + ) + assert.NoError(t, err) + assert.Equal(t, len(related), 1) + assert.Equal( + t, + related[0].TransactionIdentifier.Hash, + block2.Transactions[0].TransactionIdentifier.Hash, + ) + + _, _, related, err = storage.FindRelatedTransactions( + ctx, + block1.Transactions[1].TransactionIdentifier, + storage.db.ReadTransaction(ctx), + ) + assert.NoError(t, err) + assert.Equal(t, len(related), 1) + assert.Equal( + t, + related[0].TransactionIdentifier.Hash, + block2.Transactions[1].TransactionIdentifier.Hash, + ) + + blockID, tx, related, err := storage.FindRelatedTransactions( + ctx, + block2.Transactions[2].TransactionIdentifier, + storage.db.ReadTransaction(ctx), + ) + assert.NoError(t, err) + assert.Nil(t, blockID) + assert.Nil(t, tx) + assert.Empty(t, related) + }) +}