From c4be5cb111f6f16520bf6ed95273c115006908aa Mon Sep 17 00:00:00 2001 From: Lee Date: Wed, 4 Dec 2024 16:31:45 +0800 Subject: [PATCH 1/3] Support batch speedup of messages --- api.go | 2 +- model.go | 1 + service.go | 102 +++++++++++++++++++++++++++++++++++++++++------------ 3 files changed, 82 insertions(+), 23 deletions(-) diff --git a/api.go b/api.go index ced7939..a53b0ed 100644 --- a/api.go +++ b/api.go @@ -154,7 +154,7 @@ func (a *implAPI) speedup(w http.ResponseWriter, r *http.Request) { } } - err = a.srv.speedupRequest(r.Context(), uint(id), mss) + err = a.srv.speedupRequest(uint(id), mss) if err != nil { warpResponse(w, http.StatusBadRequest, nil, err) return diff --git a/model.go b/model.go index 8524dd6..b9cdff4 100644 --- a/model.go +++ b/model.go @@ -79,6 +79,7 @@ type Message struct { ID uint `gorm:"primarykey"` Cid CID `gorm:"index"` // The unique identifier of the message Extensions []Extension2 // The list of extensions associated with the message + Nonce uint64 RequestID uint OnChain bool ExitCode exitcode.ExitCode diff --git a/service.go b/service.go index dd073a0..0554efb 100644 --- a/service.go +++ b/service.go @@ -29,6 +29,7 @@ import ( "github.com/filecoin-project/lotus/chain/messagepool" "github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/lotus/node/config" + "github.com/ipfs/go-cid" cbor "github.com/ipfs/go-ipld-cbor" "github.com/samber/lo" "go.uber.org/zap" @@ -60,6 +61,11 @@ func (w *watchMessage) Cancel() { close(w.cancelCh) } +type speedupMessage struct { + msg *Message + mss *api.MessageSendSpec +} + type Service struct { api API adtStore adt.Store @@ -67,6 +73,7 @@ type Service struct { maxWait time.Duration wg sync.WaitGroup watchingMessages *SafeMap[uint, *watchMessage] + speedupMessages chan *speedupMessage shutdownFunc context.CancelFunc } @@ -82,14 +89,16 @@ func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration adtStore: adtStore, watchingMessages: NewSafeMap[uint, *watchMessage](), maxWait: maxWait, + speedupMessages: make(chan *speedupMessage), shutdownFunc: func() { cancel() }, } - s.wg.Add(3) + s.wg.Add(4) go s.runProcessor(ctx) go s.runMessageChecker(ctx) go s.runPendingChecker(ctx) + go s.runSpeedupWorker(ctx) return s } @@ -611,6 +620,7 @@ loopParams: Cid: CID{smsg.Cid()}, Extensions: exts, Sectors: scount, + Nonce: smsg.Message.Nonce, } s.db.Create(msg) messages = append(messages, msg) @@ -703,7 +713,7 @@ func buildParams(l miner.SectorLocation, newExp abi.ChainEpoch, numbers []abi.Se return &e2, cannotExtendSectors, sectorsInDecl, nil } -func (s *Service) speedupRequest(ctx context.Context, id uint, mss *api.MessageSendSpec) error { +func (s *Service) speedupRequest(id uint, mss *api.MessageSendSpec) error { var request Request if err := s.db.Preload(clause.Associations).First(&request, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -715,16 +725,47 @@ func (s *Service) speedupRequest(ctx context.Context, id uint, mss *api.MessageS return fmt.Errorf("request is not pending") } - for _, msg := range request.Messages { - if msg.OnChain { - continue + sort.Slice(request.Messages, func(i, j int) bool { + return request.Messages[i].Nonce < request.Messages[j].Nonce + }) + + go func() { + for _, msg := range request.Messages { + if msg.OnChain { + continue + } + s.speedupMessages <- &speedupMessage{ + msg: msg, + mss: mss, + } } - // todo: need order by nonce? - if err := s.replaceMessage(ctx, msg.ID, mss); err != nil { - return fmt.Errorf("failed to replace message: %w", err) + }() + return nil +} + +func (s *Service) runSpeedupWorker(ctx context.Context) { + defer s.wg.Done() + log.Info("starting speedup worker") + + sem := make(chan struct{}, 10) + + for { + select { + case <-ctx.Done(): + log.Info("context done, stopping speedup worker") + return + case sm := <-s.speedupMessages: + sem <- struct{}{} + go func(sm *speedupMessage) { + defer func() { + <-sem + }() + if err := s.replaceMessageAndWait(ctx, sm.msg.ID, sm.mss); err != nil { + log.Errorf("failed to replace message: %s", err) + } + }(sm) } } - return nil } func (s *Service) checkMessage(ctx context.Context, request *Request) error { @@ -813,7 +854,7 @@ func (s *Service) watchMessage(ctx context.Context, id uint) { errorChan := make(chan error, 1) go func() { - receipt, err := s.api.StateWaitMsg(ctx, msg.Cid.Cid, 2*build.MessageConfidence, api.LookbackNoLimit, true) + receipt, err := s.api.StateWaitMsg(ctx, msg.Cid.Cid, build.MessageConfidence, api.LookbackNoLimit, true) if err != nil { errorChan <- err return @@ -872,7 +913,7 @@ func (s *Service) runPendingChecker(ctx context.Context) { MaxFee: abi.TokenAmount(maxFee), } for _, id := range replaceMessages { - if err := s.replaceMessage(ctx, id, mss); err != nil { + if _, err := s.replaceMessage(ctx, id, mss); err != nil { log.Errorf("failed to replace message: %s", err) } } @@ -881,13 +922,29 @@ func (s *Service) runPendingChecker(ctx context.Context) { } } -func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageSendSpec) error { +func (s *Service) replaceMessageAndWait(ctx context.Context, id uint, mss *api.MessageSendSpec) error { + nid, err := s.replaceMessage(ctx, id, mss) + if err != nil { + return err + } + + var msg Message + if err := s.db.Preload(clause.Associations). + Where("cid = ?", nid). + First(&msg).Error; err != nil { + return fmt.Errorf("failed to get message: %w", err) + } + s.watchMessage(ctx, msg.ID) + return nil +} + +func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageSendSpec) (*cid.Cid, error) { var m Message if err := s.db.Preload(clause.Associations).First(&m, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return fmt.Errorf("message id not found in db: %d", id) + return nil, fmt.Errorf("message id not found in db: %d", id) } - return fmt.Errorf("failed to get request: %w", err) + return nil, fmt.Errorf("failed to get request: %w", err) } sLog := log.With("request", m.RequestID, "id", id, "cid", m.Cid.String()) sLog.Info("replacing message") @@ -895,17 +952,17 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS // get the message from the chain cm, err := s.api.ChainGetMessage(ctx, m.Cid.Cid) if err != nil { - return fmt.Errorf("could not find referenced message: %w", err) + return nil, fmt.Errorf("could not find referenced message: %w", err) } ts, err := s.api.ChainHead(ctx) if err != nil { - return fmt.Errorf("getting chain head: %w", err) + return nil, fmt.Errorf("getting chain head: %w", err) } pending, err := s.api.MpoolPending(ctx, ts.Key()) if err != nil { - return err + return nil, err } var found *types.SignedMessage @@ -919,13 +976,13 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS // If the message is not found in the mpool, skip it and continue with the next one if found == nil { sLog.Warn("message not found in mpool, skipping") - return nil + return nil, nil } msg := found.Message cfg, err := s.api.MpoolGetConfig(ctx) if err != nil { - return fmt.Errorf("failed to lookup the message pool config: %w", err) + return nil, fmt.Errorf("failed to lookup the message pool config: %w", err) } defaultRBF := messagepool.ComputeRBF(msg.GasPremium, cfg.ReplaceByFeeRatio) @@ -935,7 +992,7 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS msg.GasPremium = abi.NewTokenAmount(0) ret, err := s.api.GasEstimateMessageGas(ctx, &msg, mss, types.EmptyTSK) if err != nil { - return fmt.Errorf("failed to estimate gas values: %w", err) + return nil, fmt.Errorf("failed to estimate gas values: %w", err) } msg.GasPremium = big.Max(ret.GasPremium, defaultRBF) msg.GasFeeCap = big.Max(ret.GasFeeCap, msg.GasPremium) @@ -948,10 +1005,10 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS smsg, err := s.api.WalletSignMessage(ctx, msg.From, &msg) if err != nil { - return fmt.Errorf("failed to sign message: %w", err) + return nil, fmt.Errorf("failed to sign message: %w", err) } - return s.db.Transaction(func(tx *gorm.DB) error { + return lo.ToPtr(smsg.Cid()), s.db.Transaction(func(tx *gorm.DB) error { if err := tx.Delete(&m).Error; err != nil { return err } @@ -965,6 +1022,7 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS Extensions: m.Extensions, RequestID: m.RequestID, Sectors: m.Sectors, + Nonce: smsg.Message.Nonce, } if err := tx.Create(newMsg).Error; err != nil { From 42a7b023929247107579f9ae110d2e828ae28110 Mon Sep 17 00:00:00 2001 From: Lee Date: Wed, 11 Dec 2024 10:51:49 +0800 Subject: [PATCH 2/3] add batch flag --- cmd.go | 7 ++++++- service.go | 50 +++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/cmd.go b/cmd.go index bf38dbf..9075aab 100644 --- a/cmd.go +++ b/cmd.go @@ -112,6 +112,11 @@ var runCmd = &cli.Command{ Name: "max-wait", Usage: "[Warning] specify the maximum time to wait for messages on chain, otherwise try to replace them, only use this if you know what you are doing", }, + &cli.Int64Flag{ + Name: "batch-speedup", + Usage: "specify the number of messages to speed up in a single batch", + Value: 10, + }, &cli.BoolFlag{ Name: "debug", Value: false, @@ -172,7 +177,7 @@ var runCmd = &cli.Command{ authStatus = "enabled" } - service := NewService(ctx, db, fullApi, cctx.Duration("max-wait")) + service := NewService(ctx, db, fullApi, cctx.Duration("max-wait"), cctx.Int64("batch-speedup")) srv := &http.Server{ Handler: NewRouter(service, secret), Addr: cctx.String("listen"), diff --git a/service.go b/service.go index 0554efb..428e476 100644 --- a/service.go +++ b/service.go @@ -29,6 +29,7 @@ import ( "github.com/filecoin-project/lotus/chain/messagepool" "github.com/filecoin-project/lotus/chain/types" "github.com/filecoin-project/lotus/node/config" + "github.com/ipfs/go-cid" cbor "github.com/ipfs/go-ipld-cbor" "github.com/samber/lo" @@ -43,15 +44,15 @@ const ( type watchMessage struct { id uint - cid CID + db *gorm.DB started time.Time cancelCh chan struct{} } -func newWatchMessage(m *Message) *watchMessage { +func newWatchMessage(db *gorm.DB, id uint) *watchMessage { return &watchMessage{ - id: m.ID, - cid: m.Cid, + db: db, + id: id, started: time.Now(), cancelCh: make(chan struct{}), } @@ -61,6 +62,27 @@ func (w *watchMessage) Cancel() { close(w.cancelCh) } +func (w *watchMessage) Wait() { + tk := time.NewTicker(5 * time.Second) + defer tk.Stop() + + for { + select { + case <-w.cancelCh: + return + case <-tk.C: + var msg Message + if err := w.db.First(&msg, w.id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return + } else { + log.Errorf("failed to get message: %s", err) + } + } + } + } +} + type speedupMessage struct { msg *Message mss *api.MessageSendSpec @@ -74,10 +96,11 @@ type Service struct { wg sync.WaitGroup watchingMessages *SafeMap[uint, *watchMessage] speedupMessages chan *speedupMessage + batchSpeedup int64 shutdownFunc context.CancelFunc } -func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration) *Service { +func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration, batchSpeedup int64) *Service { tbs := blockstore.NewTieredBstore(blockstore.NewAPIBlockstore(api), blockstore.NewMemory()) adtStore := adt.WrapStore(ctx, cbor.NewCborStore(tbs)) @@ -90,6 +113,7 @@ func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration watchingMessages: NewSafeMap[uint, *watchMessage](), maxWait: maxWait, speedupMessages: make(chan *speedupMessage), + batchSpeedup: batchSpeedup, shutdownFunc: func() { cancel() }, @@ -747,7 +771,7 @@ func (s *Service) runSpeedupWorker(ctx context.Context) { defer s.wg.Done() log.Info("starting speedup worker") - sem := make(chan struct{}, 10) + sem := make(chan struct{}, s.batchSpeedup) for { select { @@ -844,9 +868,12 @@ func (s *Service) watchMessage(ctx context.Context, id uint) { sLog.Warn("message is already watching") return } - wm := newWatchMessage(&msg) + wm := newWatchMessage(s.db, msg.ID) s.watchingMessages.Set(msg.ID, wm) - defer s.watchingMessages.Delete(msg.ID) + defer func() { + wm.Cancel() + s.watchingMessages.Delete(msg.ID) + }() sLog.Info("watching message") @@ -934,7 +961,12 @@ func (s *Service) replaceMessageAndWait(ctx context.Context, id uint, mss *api.M First(&msg).Error; err != nil { return fmt.Errorf("failed to get message: %w", err) } - s.watchMessage(ctx, msg.ID) + + if wm, ok := s.watchingMessages.Get(msg.ID); ok { + wm.Wait() + } else { + s.watchMessage(ctx, msg.ID) + } return nil } From 2c8c3e19e6a1782de5c12bf683533328f6e14434 Mon Sep 17 00:00:00 2001 From: Lee Date: Wed, 11 Dec 2024 11:03:17 +0800 Subject: [PATCH 3/3] fix: channel close once --- service.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/service.go b/service.go index 428e476..161b61e 100644 --- a/service.go +++ b/service.go @@ -43,10 +43,11 @@ const ( ) type watchMessage struct { - id uint - db *gorm.DB - started time.Time - cancelCh chan struct{} + id uint + db *gorm.DB + started time.Time + cancelOnce sync.Once + cancelCh chan struct{} } func newWatchMessage(db *gorm.DB, id uint) *watchMessage { @@ -59,7 +60,9 @@ func newWatchMessage(db *gorm.DB, id uint) *watchMessage { } func (w *watchMessage) Cancel() { - close(w.cancelCh) + w.cancelOnce.Do(func() { + close(w.cancelCh) + }) } func (w *watchMessage) Wait() {