diff --git a/services/wallet/async.go b/services/wallet/async.go new file mode 100644 index 0000000000..fbc551de9b --- /dev/null +++ b/services/wallet/async.go @@ -0,0 +1,75 @@ +package wallet + +import ( + "context" + "sync" + "time" +) + +type Command interface { + Run(context.Context) +} + +type FiniteCommand struct { + Interval time.Duration + Runable func(context.Context) error +} + +func (c FiniteCommand) Run(ctx context.Context) { + ticker := time.NewTicker(c.Interval) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := c.Runable(ctx) + if err == nil { + return + } + } + } +} + +type InfiniteCommand struct { + Interval time.Duration + Runable func(context.Context) error +} + +func (c InfiniteCommand) Run(ctx context.Context) { + ticker := time.NewTicker(c.Interval) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _ = c.Runable(ctx) + } + } +} + +func NewGroup() *Group { + ctx, cancel := context.WithCancel(context.Background()) + return &Group{ + ctx: ctx, + cancel: cancel, + } +} + +type Group struct { + ctx context.Context + cancel func() + wg sync.WaitGroup +} + +func (g *Group) Add(cmd Command) { + g.wg.Add(1) + go func() { + cmd.Run(g.ctx) + g.wg.Done() + }() +} + +func (g *Group) Stop() { + g.cancel() + g.wg.Wait() +} diff --git a/services/wallet/commands.go b/services/wallet/commands.go index 07c4d2f875..e9071cce4d 100644 --- a/services/wallet/commands.go +++ b/services/wallet/commands.go @@ -2,8 +2,8 @@ package wallet import ( "context" + "errors" "math/big" - "sync" "time" "github.com/ethereum/go-ethereum/common" @@ -12,74 +12,6 @@ import ( "github.com/ethereum/go-ethereum/log" ) -type Command interface { - Run(context.Context) -} - -type FiniteCommand struct { - Interval time.Duration - Runable func(context.Context) error -} - -func (c FiniteCommand) Run(ctx context.Context) { - ticker := time.NewTicker(c.Interval) - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - err := c.Runable(ctx) - if err == nil { - return - } - } - } -} - -type InfiniteCommand struct { - Interval time.Duration - Runable func(context.Context) error -} - -func (c InfiniteCommand) Run(ctx context.Context) { - ticker := time.NewTicker(c.Interval) - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - _ = c.Runable(ctx) - } - } -} - -func NewGroup() *Group { - ctx, cancel := context.WithCancel(context.Background()) - return &Group{ - ctx: ctx, - cancel: cancel, - } -} - -type Group struct { - ctx context.Context - cancel func() - wg sync.WaitGroup -} - -func (g *Group) Add(cmd Command) { - g.wg.Add(1) - go func() { - cmd.Run(g.ctx) - g.wg.Done() - }() -} - -func (g *Group) Stop() { - g.cancel() - g.wg.Wait() -} - type ethHistoricalCommand struct { db *Database eth TransferDownloader @@ -115,12 +47,17 @@ func (c *ethHistoricalCommand) Run(ctx context.Context) (err error) { concurrent := NewConcurrentDownloader(ctx) start := time.Now() downloadEthConcurrently(concurrent, c.client, c.eth, c.address, zero, c.previous.Number) - concurrent.Wait() + select { + case <-concurrent.WaitAsync(): + case <-ctx.Done(): + log.Error("eth downloader is stuck") + return errors.New("eth downloader is stuck") + } if concurrent.Error() != nil { log.Error("failed to dowloader transfers using concurrent downloader", "error", err) return concurrent.Error() } - transfers := concurrent.Transfers() + transfers := concurrent.Get() log.Info("eth historical downloader finished succesfully", "total transfers", len(transfers), "time", time.Since(start)) // TODO(dshulyak) insert 0 block number with transfers err = c.db.ProcessTranfers(transfers, headersFromTransfers(transfers), nil, ethSync) diff --git a/services/wallet/concurrent.go b/services/wallet/concurrent.go index 2c0d1f6881..6d071326e8 100644 --- a/services/wallet/concurrent.go +++ b/services/wallet/concurrent.go @@ -11,30 +11,59 @@ import ( // NewConcurrentDownloader creates ConcurrentDownloader instance. func NewConcurrentDownloader(ctx context.Context) *ConcurrentDownloader { + runner := NewConcurrentRunner(ctx) + result := &Result{} + return &ConcurrentDownloader{runner, result} +} + +type ConcurrentDownloader struct { + *ConcurrentRunner + *Result +} + +type Result struct { + mu sync.Mutex + transfers []Transfer +} + +func (r *Result) Add(transfers ...Transfer) { + r.mu.Lock() + defer r.mu.Unlock() + r.transfers = append(r.transfers, transfers...) +} + +func (r *Result) Get() []Transfer { + r.mu.Lock() + defer r.mu.Unlock() + rst := make([]Transfer, len(r.transfers)) + copy(rst, r.transfers) + return rst +} + +func NewConcurrentRunner(ctx context.Context) *ConcurrentRunner { ctx, cancel := context.WithCancel(ctx) - return &ConcurrentDownloader{ + return &ConcurrentRunner{ ctx: ctx, cancel: cancel, } } -// ConcurrentDownloader manages downloaders life cycle. -type ConcurrentDownloader struct { +// ConcurrentRunner runs group atomically. +type ConcurrentRunner struct { ctx context.Context cancel func() wg sync.WaitGroup - mu sync.Mutex - results []Transfer - error error + mu sync.Mutex + error error } // Go spawns function in a goroutine and stores results or errors. -func (d *ConcurrentDownloader) Go(f func(context.Context) ([]Transfer, error)) { +func (d *ConcurrentRunner) Go(f func(context.Context) error) { d.wg.Add(1) go func() { defer d.wg.Done() - transfers, err := f(d.ctx) + err := f(d.ctx) d.mu.Lock() defer d.mu.Unlock() if err != nil { @@ -46,29 +75,30 @@ func (d *ConcurrentDownloader) Go(f func(context.Context) ([]Transfer, error)) { d.cancel() return } - d.results = append(d.results, transfers...) }() } -// Transfers returns collected transfers. To get all results should be called after Wait. -func (d *ConcurrentDownloader) Transfers() []Transfer { - d.mu.Lock() - defer d.mu.Unlock() - rst := make([]Transfer, len(d.results)) - copy(rst, d.results) - return rst -} - // Wait for all downloaders to finish. -func (d *ConcurrentDownloader) Wait() { +func (d *ConcurrentRunner) Wait() { d.wg.Wait() if d.Error() == nil { + d.mu.Lock() + defer d.mu.Unlock() d.cancel() } } +func (d *ConcurrentRunner) WaitAsync() <-chan struct{} { + ch := make(chan struct{}) + go func() { + d.Wait() + close(ch) + }() + return ch +} + // Error stores an error that was reported by any of the downloader. Should be called after Wait. -func (d *ConcurrentDownloader) Error() error { +func (d *ConcurrentRunner) Error() error { d.mu.Lock() defer d.mu.Unlock() return d.error @@ -80,29 +110,34 @@ type TransferDownloader interface { } func downloadEthConcurrently(c *ConcurrentDownloader, client BalanceReader, downloader TransferDownloader, account common.Address, low, high *big.Int) { - c.Go(func(ctx context.Context) ([]Transfer, error) { + c.Go(func(ctx context.Context) error { log.Debug("eth transfers comparing blocks", "low", low, "high", high) lb, err := client.BalanceAt(ctx, account, low) if err != nil { - return nil, err + return err } hb, err := client.BalanceAt(ctx, account, high) if err != nil { - return nil, err + return err } if lb.Cmp(hb) == 0 { log.Debug("balances are equal", "low", low, "high", high) - return nil, nil + return nil } if new(big.Int).Sub(high, low).Cmp(one) == 0 { log.Debug("higher block is a parent", "low", low, "high", high) - return downloader.GetTransfersByNumber(ctx, high) + transfers, err := downloader.GetTransfersByNumber(ctx, high) + if err != nil { + return err + } + c.Add(transfers...) + return nil } mid := new(big.Int).Add(low, high) mid = mid.Div(mid, two) log.Debug("balances are not equal spawn two concurrent downloaders", "low", low, "mid", mid, "high", high) downloadEthConcurrently(c, client, downloader, account, low, mid) downloadEthConcurrently(c, client, downloader, account, mid, high) - return nil, nil + return nil }) } diff --git a/services/wallet/concurrent_test.go b/services/wallet/concurrent_test.go index 859a41497b..e2ba6a400a 100644 --- a/services/wallet/concurrent_test.go +++ b/services/wallet/concurrent_test.go @@ -15,17 +15,17 @@ import ( func TestConcurrentErrorInterrupts(t *testing.T) { concurrent := NewConcurrentDownloader(context.Background()) var interrupted bool - concurrent.Go(func(ctx context.Context) ([]Transfer, error) { + concurrent.Go(func(ctx context.Context) error { select { case <-ctx.Done(): interrupted = true case <-time.After(10 * time.Second): } - return nil, nil + return nil }) err := errors.New("interrupt") - concurrent.Go(func(ctx context.Context) ([]Transfer, error) { - return nil, err + concurrent.Go(func(ctx context.Context) error { + return err }) concurrent.Wait() require.True(t, interrupted) @@ -34,14 +34,16 @@ func TestConcurrentErrorInterrupts(t *testing.T) { func TestConcurrentCollectsTransfers(t *testing.T) { concurrent := NewConcurrentDownloader(context.Background()) - concurrent.Go(func(context.Context) ([]Transfer, error) { - return []Transfer{{}}, nil + concurrent.Go(func(context.Context) error { + concurrent.Add(Transfer{}) + return nil }) - concurrent.Go(func(context.Context) ([]Transfer, error) { - return []Transfer{{}}, nil + concurrent.Go(func(context.Context) error { + concurrent.Add(Transfer{}) + return nil }) concurrent.Wait() - require.Len(t, concurrent.Transfers(), 2) + require.Len(t, concurrent.Get(), 2) } type balancesFixture []*big.Int @@ -111,7 +113,7 @@ func TestConcurrentEthDownloader(t *testing.T) { common.Address{}, zero, tc.options.last) concurrent.Wait() require.NoError(t, concurrent.Error()) - rst := concurrent.Transfers() + rst := concurrent.Get() require.Len(t, rst, len(tc.options.result)) sort.Slice(rst, func(i, j int) bool { return rst[i].BlockNumber.Cmp(rst[j].BlockNumber) < 0 diff --git a/services/wallet/downloader.go b/services/wallet/downloader.go index cc9b95b9f7..24a4f4f594 100644 --- a/services/wallet/downloader.go +++ b/services/wallet/downloader.go @@ -2,6 +2,7 @@ package wallet import ( "context" + "errors" "math/big" "time" @@ -163,32 +164,48 @@ func (d *ERC20TransfersDownloader) outboundTopics(address common.Address) [][]co return [][]common.Hash{{d.signature}, {d.paddedAddress(address)}, {}} } +func (d *ERC20TransfersDownloader) tranasferFromLogs(parent context.Context, log types.Log, address common.Address) (Transfer, error) { + ctx, cancel := context.WithTimeout(parent, 3*time.Second) + tx, _, err := d.client.TransactionByHash(ctx, log.TxHash) + cancel() + if err != nil { + return Transfer{}, err + } + ctx, cancel = context.WithTimeout(parent, 3*time.Second) + receipt, err := d.client.TransactionReceipt(ctx, log.TxHash) + cancel() + if err != nil { + return Transfer{}, err + } + return Transfer{ + Address: address, + Type: erc20Transfer, + BlockNumber: new(big.Int).SetUint64(log.BlockNumber), + BlockHash: log.BlockHash, + Transaction: tx, + Receipt: receipt, + }, nil +} + func (d *ERC20TransfersDownloader) transfersFromLogs(parent context.Context, logs []types.Log, address common.Address) ([]Transfer, error) { - rst := make([]Transfer, len(logs)) - for i, l := range logs { - // TODO(dshulyak) use TransactionInBlock after it is fixed - ctx, cancel := context.WithTimeout(parent, 3*time.Second) - tx, _, err := d.client.TransactionByHash(ctx, l.TxHash) - cancel() - if err != nil { - return nil, err - } - ctx, cancel = context.WithTimeout(parent, 3*time.Second) - receipt, err := d.client.TransactionReceipt(ctx, l.TxHash) - cancel() - if err != nil { - return nil, err - } - rst[i] = Transfer{ - Address: address, - Type: erc20Transfer, - BlockNumber: new(big.Int).SetUint64(l.BlockNumber), - BlockHash: l.BlockHash, - Transaction: tx, - Receipt: receipt, - } + concurrent := NewConcurrentDownloader(parent) + for i := range logs { + l := logs[i] + concurrent.Go(func(ctx context.Context) error { + transfer, err := d.tranasferFromLogs(ctx, l, address) + if err != nil { + return err + } + concurrent.Add(transfer) + return nil + }) } - return rst, nil + select { + case <-concurrent.WaitAsync(): + case <-parent.Done(): + return nil, errors.New("logs downloader stuck") + } + return concurrent.Get(), nil } func any(address common.Address, compare []common.Address) bool {