Skip to content

Commit

Permalink
Support batch speedup of messages (#23)
Browse files Browse the repository at this point in the history
* Support batch speedup of messages

* add batch flag

* fix: channel close once
  • Loading branch information
strahe authored Dec 12, 2024
1 parent b1e7c40 commit 8061ab7
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 35 deletions.
2 changes: 1 addition & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (a *implAPI) speedup(w http.ResponseWriter, r *http.Request) {
}
}

err = a.srv.speedupRequest(r.Context(), uint(id), mss)
err = a.srv.speedupRequest(uint(id), mss)
if err != nil {
warpResponse(w, http.StatusBadRequest, nil, err)
return
Expand Down
7 changes: 6 additions & 1 deletion cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ var runCmd = &cli.Command{
Name: "max-wait",
Usage: "[Warning] specify the maximum time to wait for messages on chain, otherwise try to replace them, only use this if you know what you are doing",
},
&cli.Int64Flag{
Name: "batch-speedup",
Usage: "specify the number of messages to speed up in a single batch",
Value: 10,
},
&cli.BoolFlag{
Name: "debug",
Value: false,
Expand Down Expand Up @@ -172,7 +177,7 @@ var runCmd = &cli.Command{
authStatus = "enabled"
}

service := NewService(ctx, db, fullApi, cctx.Duration("max-wait"))
service := NewService(ctx, db, fullApi, cctx.Duration("max-wait"), cctx.Int64("batch-speedup"))
srv := &http.Server{
Handler: NewRouter(service, secret),
Addr: cctx.String("listen"),
Expand Down
1 change: 1 addition & 0 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type Message struct {
ID uint `gorm:"primarykey"`
Cid CID `gorm:"index"` // The unique identifier of the message
Extensions []Extension2 // The list of extensions associated with the message
Nonce uint64
RequestID uint
OnChain bool
ExitCode exitcode.ExitCode
Expand Down
159 changes: 126 additions & 33 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"github.com/filecoin-project/lotus/chain/messagepool"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/lotus/node/config"

"github.com/ipfs/go-cid"
cbor "github.com/ipfs/go-ipld-cbor"
"github.com/samber/lo"
"go.uber.org/zap"
Expand All @@ -41,23 +43,52 @@ const (
)

type watchMessage struct {
id uint
cid CID
started time.Time
cancelCh chan struct{}
id uint
db *gorm.DB
started time.Time
cancelOnce sync.Once
cancelCh chan struct{}
}

func newWatchMessage(m *Message) *watchMessage {
func newWatchMessage(db *gorm.DB, id uint) *watchMessage {
return &watchMessage{
id: m.ID,
cid: m.Cid,
db: db,
id: id,
started: time.Now(),
cancelCh: make(chan struct{}),
}
}

func (w *watchMessage) Cancel() {
close(w.cancelCh)
w.cancelOnce.Do(func() {
close(w.cancelCh)
})
}

func (w *watchMessage) Wait() {
tk := time.NewTicker(5 * time.Second)
defer tk.Stop()

for {
select {
case <-w.cancelCh:
return
case <-tk.C:
var msg Message
if err := w.db.First(&msg, w.id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return
} else {
log.Errorf("failed to get message: %s", err)
}
}
}
}
}

type speedupMessage struct {
msg *Message
mss *api.MessageSendSpec
}

type Service struct {
Expand All @@ -67,10 +98,12 @@ type Service struct {
maxWait time.Duration
wg sync.WaitGroup
watchingMessages *SafeMap[uint, *watchMessage]
speedupMessages chan *speedupMessage
batchSpeedup int64
shutdownFunc context.CancelFunc
}

func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration) *Service {
func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration, batchSpeedup int64) *Service {
tbs := blockstore.NewTieredBstore(blockstore.NewAPIBlockstore(api), blockstore.NewMemory())
adtStore := adt.WrapStore(ctx, cbor.NewCborStore(tbs))

Expand All @@ -82,14 +115,17 @@ func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration
adtStore: adtStore,
watchingMessages: NewSafeMap[uint, *watchMessage](),
maxWait: maxWait,
speedupMessages: make(chan *speedupMessage),
batchSpeedup: batchSpeedup,
shutdownFunc: func() {
cancel()
},
}
s.wg.Add(3)
s.wg.Add(4)
go s.runProcessor(ctx)
go s.runMessageChecker(ctx)
go s.runPendingChecker(ctx)
go s.runSpeedupWorker(ctx)
return s
}

Expand Down Expand Up @@ -611,6 +647,7 @@ loopParams:
Cid: CID{smsg.Cid()},
Extensions: exts,
Sectors: scount,
Nonce: smsg.Message.Nonce,
}
s.db.Create(msg)
messages = append(messages, msg)
Expand Down Expand Up @@ -703,7 +740,7 @@ func buildParams(l miner.SectorLocation, newExp abi.ChainEpoch, numbers []abi.Se
return &e2, cannotExtendSectors, sectorsInDecl, nil
}

func (s *Service) speedupRequest(ctx context.Context, id uint, mss *api.MessageSendSpec) error {
func (s *Service) speedupRequest(id uint, mss *api.MessageSendSpec) error {
var request Request
if err := s.db.Preload(clause.Associations).First(&request, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
Expand All @@ -715,16 +752,47 @@ func (s *Service) speedupRequest(ctx context.Context, id uint, mss *api.MessageS
return fmt.Errorf("request is not pending")
}

for _, msg := range request.Messages {
if msg.OnChain {
continue
sort.Slice(request.Messages, func(i, j int) bool {
return request.Messages[i].Nonce < request.Messages[j].Nonce
})

go func() {
for _, msg := range request.Messages {
if msg.OnChain {
continue
}
s.speedupMessages <- &speedupMessage{
msg: msg,
mss: mss,
}
}
// todo: need order by nonce?
if err := s.replaceMessage(ctx, msg.ID, mss); err != nil {
return fmt.Errorf("failed to replace message: %w", err)
}()
return nil
}

func (s *Service) runSpeedupWorker(ctx context.Context) {
defer s.wg.Done()
log.Info("starting speedup worker")

sem := make(chan struct{}, s.batchSpeedup)

for {
select {
case <-ctx.Done():
log.Info("context done, stopping speedup worker")
return
case sm := <-s.speedupMessages:
sem <- struct{}{}
go func(sm *speedupMessage) {
defer func() {
<-sem
}()
if err := s.replaceMessageAndWait(ctx, sm.msg.ID, sm.mss); err != nil {
log.Errorf("failed to replace message: %s", err)
}
}(sm)
}
}
return nil
}

func (s *Service) checkMessage(ctx context.Context, request *Request) error {
Expand Down Expand Up @@ -803,17 +871,20 @@ func (s *Service) watchMessage(ctx context.Context, id uint) {
sLog.Warn("message is already watching")
return
}
wm := newWatchMessage(&msg)
wm := newWatchMessage(s.db, msg.ID)
s.watchingMessages.Set(msg.ID, wm)
defer s.watchingMessages.Delete(msg.ID)
defer func() {
wm.Cancel()
s.watchingMessages.Delete(msg.ID)
}()

sLog.Info("watching message")

resultChan := make(chan *api.MsgLookup, 1)
errorChan := make(chan error, 1)

go func() {
receipt, err := s.api.StateWaitMsg(ctx, msg.Cid.Cid, 2*build.MessageConfidence, api.LookbackNoLimit, true)
receipt, err := s.api.StateWaitMsg(ctx, msg.Cid.Cid, build.MessageConfidence, api.LookbackNoLimit, true)
if err != nil {
errorChan <- err
return
Expand Down Expand Up @@ -872,7 +943,7 @@ func (s *Service) runPendingChecker(ctx context.Context) {
MaxFee: abi.TokenAmount(maxFee),
}
for _, id := range replaceMessages {
if err := s.replaceMessage(ctx, id, mss); err != nil {
if _, err := s.replaceMessage(ctx, id, mss); err != nil {
log.Errorf("failed to replace message: %s", err)
}
}
Expand All @@ -881,31 +952,52 @@ func (s *Service) runPendingChecker(ctx context.Context) {
}
}

func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageSendSpec) error {
func (s *Service) replaceMessageAndWait(ctx context.Context, id uint, mss *api.MessageSendSpec) error {
nid, err := s.replaceMessage(ctx, id, mss)
if err != nil {
return err
}

var msg Message
if err := s.db.Preload(clause.Associations).
Where("cid = ?", nid).
First(&msg).Error; err != nil {
return fmt.Errorf("failed to get message: %w", err)
}

if wm, ok := s.watchingMessages.Get(msg.ID); ok {
wm.Wait()
} else {
s.watchMessage(ctx, msg.ID)
}
return nil
}

func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageSendSpec) (*cid.Cid, error) {
var m Message
if err := s.db.Preload(clause.Associations).First(&m, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("message id not found in db: %d", id)
return nil, fmt.Errorf("message id not found in db: %d", id)
}
return fmt.Errorf("failed to get request: %w", err)
return nil, fmt.Errorf("failed to get request: %w", err)
}
sLog := log.With("request", m.RequestID, "id", id, "cid", m.Cid.String())
sLog.Info("replacing message")

// get the message from the chain
cm, err := s.api.ChainGetMessage(ctx, m.Cid.Cid)
if err != nil {
return fmt.Errorf("could not find referenced message: %w", err)
return nil, fmt.Errorf("could not find referenced message: %w", err)
}

ts, err := s.api.ChainHead(ctx)
if err != nil {
return fmt.Errorf("getting chain head: %w", err)
return nil, fmt.Errorf("getting chain head: %w", err)
}

pending, err := s.api.MpoolPending(ctx, ts.Key())
if err != nil {
return err
return nil, err
}

var found *types.SignedMessage
Expand All @@ -919,13 +1011,13 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS
// If the message is not found in the mpool, skip it and continue with the next one
if found == nil {
sLog.Warn("message not found in mpool, skipping")
return nil
return nil, nil
}
msg := found.Message

cfg, err := s.api.MpoolGetConfig(ctx)
if err != nil {
return fmt.Errorf("failed to lookup the message pool config: %w", err)
return nil, fmt.Errorf("failed to lookup the message pool config: %w", err)
}

defaultRBF := messagepool.ComputeRBF(msg.GasPremium, cfg.ReplaceByFeeRatio)
Expand All @@ -935,7 +1027,7 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS
msg.GasPremium = abi.NewTokenAmount(0)
ret, err := s.api.GasEstimateMessageGas(ctx, &msg, mss, types.EmptyTSK)
if err != nil {
return fmt.Errorf("failed to estimate gas values: %w", err)
return nil, fmt.Errorf("failed to estimate gas values: %w", err)
}
msg.GasPremium = big.Max(ret.GasPremium, defaultRBF)
msg.GasFeeCap = big.Max(ret.GasFeeCap, msg.GasPremium)
Expand All @@ -948,10 +1040,10 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS

smsg, err := s.api.WalletSignMessage(ctx, msg.From, &msg)
if err != nil {
return fmt.Errorf("failed to sign message: %w", err)
return nil, fmt.Errorf("failed to sign message: %w", err)
}

return s.db.Transaction(func(tx *gorm.DB) error {
return lo.ToPtr(smsg.Cid()), s.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Delete(&m).Error; err != nil {
return err
}
Expand All @@ -965,6 +1057,7 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS
Extensions: m.Extensions,
RequestID: m.RequestID,
Sectors: m.Sectors,
Nonce: smsg.Message.Nonce,
}

if err := tx.Create(newMsg).Error; err != nil {
Expand Down

0 comments on commit 8061ab7

Please sign in to comment.