Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support batch speedup of messages #23

Merged
merged 3 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (a *implAPI) speedup(w http.ResponseWriter, r *http.Request) {
}
}

err = a.srv.speedupRequest(r.Context(), uint(id), mss)
err = a.srv.speedupRequest(uint(id), mss)
if err != nil {
warpResponse(w, http.StatusBadRequest, nil, err)
return
Expand Down
7 changes: 6 additions & 1 deletion cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ var runCmd = &cli.Command{
Name: "max-wait",
Usage: "[Warning] specify the maximum time to wait for messages on chain, otherwise try to replace them, only use this if you know what you are doing",
},
&cli.Int64Flag{
Name: "batch-speedup",
Usage: "specify the number of messages to speed up in a single batch",
Value: 10,
},
&cli.BoolFlag{
Name: "debug",
Value: false,
Expand Down Expand Up @@ -172,7 +177,7 @@ var runCmd = &cli.Command{
authStatus = "enabled"
}

service := NewService(ctx, db, fullApi, cctx.Duration("max-wait"))
service := NewService(ctx, db, fullApi, cctx.Duration("max-wait"), cctx.Int64("batch-speedup"))
srv := &http.Server{
Handler: NewRouter(service, secret),
Addr: cctx.String("listen"),
Expand Down
1 change: 1 addition & 0 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type Message struct {
ID uint `gorm:"primarykey"`
Cid CID `gorm:"index"` // The unique identifier of the message
Extensions []Extension2 // The list of extensions associated with the message
Nonce uint64
RequestID uint
OnChain bool
ExitCode exitcode.ExitCode
Expand Down
159 changes: 126 additions & 33 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"github.com/filecoin-project/lotus/chain/messagepool"
"github.com/filecoin-project/lotus/chain/types"
"github.com/filecoin-project/lotus/node/config"

"github.com/ipfs/go-cid"
cbor "github.com/ipfs/go-ipld-cbor"
"github.com/samber/lo"
"go.uber.org/zap"
Expand All @@ -41,23 +43,52 @@ const (
)

type watchMessage struct {
id uint
cid CID
started time.Time
cancelCh chan struct{}
id uint
db *gorm.DB
started time.Time
cancelOnce sync.Once
cancelCh chan struct{}
}

func newWatchMessage(m *Message) *watchMessage {
func newWatchMessage(db *gorm.DB, id uint) *watchMessage {
return &watchMessage{
id: m.ID,
cid: m.Cid,
db: db,
id: id,
started: time.Now(),
cancelCh: make(chan struct{}),
}
}

func (w *watchMessage) Cancel() {
close(w.cancelCh)
w.cancelOnce.Do(func() {
close(w.cancelCh)
})
}

func (w *watchMessage) Wait() {
tk := time.NewTicker(5 * time.Second)
defer tk.Stop()

for {
select {
case <-w.cancelCh:
return
case <-tk.C:
var msg Message
if err := w.db.First(&msg, w.id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return
} else {
log.Errorf("failed to get message: %s", err)
}
}
}
}
}

type speedupMessage struct {
msg *Message
mss *api.MessageSendSpec
}

type Service struct {
Expand All @@ -67,10 +98,12 @@ type Service struct {
maxWait time.Duration
wg sync.WaitGroup
watchingMessages *SafeMap[uint, *watchMessage]
speedupMessages chan *speedupMessage
batchSpeedup int64
shutdownFunc context.CancelFunc
}

func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration) *Service {
func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration, batchSpeedup int64) *Service {
tbs := blockstore.NewTieredBstore(blockstore.NewAPIBlockstore(api), blockstore.NewMemory())
adtStore := adt.WrapStore(ctx, cbor.NewCborStore(tbs))

Expand All @@ -82,14 +115,17 @@ func NewService(ctx context.Context, db *gorm.DB, api API, maxWait time.Duration
adtStore: adtStore,
watchingMessages: NewSafeMap[uint, *watchMessage](),
maxWait: maxWait,
speedupMessages: make(chan *speedupMessage),
batchSpeedup: batchSpeedup,
shutdownFunc: func() {
cancel()
},
}
s.wg.Add(3)
s.wg.Add(4)
go s.runProcessor(ctx)
go s.runMessageChecker(ctx)
go s.runPendingChecker(ctx)
go s.runSpeedupWorker(ctx)
return s
}

Expand Down Expand Up @@ -611,6 +647,7 @@ loopParams:
Cid: CID{smsg.Cid()},
Extensions: exts,
Sectors: scount,
Nonce: smsg.Message.Nonce,
}
s.db.Create(msg)
messages = append(messages, msg)
Expand Down Expand Up @@ -703,7 +740,7 @@ func buildParams(l miner.SectorLocation, newExp abi.ChainEpoch, numbers []abi.Se
return &e2, cannotExtendSectors, sectorsInDecl, nil
}

func (s *Service) speedupRequest(ctx context.Context, id uint, mss *api.MessageSendSpec) error {
func (s *Service) speedupRequest(id uint, mss *api.MessageSendSpec) error {
var request Request
if err := s.db.Preload(clause.Associations).First(&request, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
Expand All @@ -715,16 +752,47 @@ func (s *Service) speedupRequest(ctx context.Context, id uint, mss *api.MessageS
return fmt.Errorf("request is not pending")
}

for _, msg := range request.Messages {
if msg.OnChain {
continue
sort.Slice(request.Messages, func(i, j int) bool {
return request.Messages[i].Nonce < request.Messages[j].Nonce
})

go func() {
for _, msg := range request.Messages {
if msg.OnChain {
continue
}
s.speedupMessages <- &speedupMessage{
msg: msg,
mss: mss,
}
}
// todo: need order by nonce?
if err := s.replaceMessage(ctx, msg.ID, mss); err != nil {
return fmt.Errorf("failed to replace message: %w", err)
}()
return nil
}

func (s *Service) runSpeedupWorker(ctx context.Context) {
defer s.wg.Done()
log.Info("starting speedup worker")

sem := make(chan struct{}, s.batchSpeedup)

for {
select {
case <-ctx.Done():
log.Info("context done, stopping speedup worker")
return
case sm := <-s.speedupMessages:
sem <- struct{}{}
go func(sm *speedupMessage) {
defer func() {
<-sem
}()
if err := s.replaceMessageAndWait(ctx, sm.msg.ID, sm.mss); err != nil {
log.Errorf("failed to replace message: %s", err)
}
}(sm)
}
}
return nil
}

func (s *Service) checkMessage(ctx context.Context, request *Request) error {
Expand Down Expand Up @@ -803,17 +871,20 @@ func (s *Service) watchMessage(ctx context.Context, id uint) {
sLog.Warn("message is already watching")
return
}
wm := newWatchMessage(&msg)
wm := newWatchMessage(s.db, msg.ID)
s.watchingMessages.Set(msg.ID, wm)
defer s.watchingMessages.Delete(msg.ID)
defer func() {
wm.Cancel()
s.watchingMessages.Delete(msg.ID)
}()

sLog.Info("watching message")

resultChan := make(chan *api.MsgLookup, 1)
errorChan := make(chan error, 1)

go func() {
receipt, err := s.api.StateWaitMsg(ctx, msg.Cid.Cid, 2*build.MessageConfidence, api.LookbackNoLimit, true)
receipt, err := s.api.StateWaitMsg(ctx, msg.Cid.Cid, build.MessageConfidence, api.LookbackNoLimit, true)
if err != nil {
errorChan <- err
return
Expand Down Expand Up @@ -872,7 +943,7 @@ func (s *Service) runPendingChecker(ctx context.Context) {
MaxFee: abi.TokenAmount(maxFee),
}
for _, id := range replaceMessages {
if err := s.replaceMessage(ctx, id, mss); err != nil {
if _, err := s.replaceMessage(ctx, id, mss); err != nil {
log.Errorf("failed to replace message: %s", err)
}
}
Expand All @@ -881,31 +952,52 @@ func (s *Service) runPendingChecker(ctx context.Context) {
}
}

func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageSendSpec) error {
func (s *Service) replaceMessageAndWait(ctx context.Context, id uint, mss *api.MessageSendSpec) error {
nid, err := s.replaceMessage(ctx, id, mss)
if err != nil {
return err
}

var msg Message
if err := s.db.Preload(clause.Associations).
Where("cid = ?", nid).
First(&msg).Error; err != nil {
return fmt.Errorf("failed to get message: %w", err)
}

if wm, ok := s.watchingMessages.Get(msg.ID); ok {
wm.Wait()
} else {
s.watchMessage(ctx, msg.ID)
}
return nil
}

func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageSendSpec) (*cid.Cid, error) {
var m Message
if err := s.db.Preload(clause.Associations).First(&m, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("message id not found in db: %d", id)
return nil, fmt.Errorf("message id not found in db: %d", id)
}
return fmt.Errorf("failed to get request: %w", err)
return nil, fmt.Errorf("failed to get request: %w", err)
}
sLog := log.With("request", m.RequestID, "id", id, "cid", m.Cid.String())
sLog.Info("replacing message")

// get the message from the chain
cm, err := s.api.ChainGetMessage(ctx, m.Cid.Cid)
if err != nil {
return fmt.Errorf("could not find referenced message: %w", err)
return nil, fmt.Errorf("could not find referenced message: %w", err)
}

ts, err := s.api.ChainHead(ctx)
if err != nil {
return fmt.Errorf("getting chain head: %w", err)
return nil, fmt.Errorf("getting chain head: %w", err)
}

pending, err := s.api.MpoolPending(ctx, ts.Key())
if err != nil {
return err
return nil, err
}

var found *types.SignedMessage
Expand All @@ -919,13 +1011,13 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS
// If the message is not found in the mpool, skip it and continue with the next one
if found == nil {
sLog.Warn("message not found in mpool, skipping")
return nil
return nil, nil
}
msg := found.Message

cfg, err := s.api.MpoolGetConfig(ctx)
if err != nil {
return fmt.Errorf("failed to lookup the message pool config: %w", err)
return nil, fmt.Errorf("failed to lookup the message pool config: %w", err)
}

defaultRBF := messagepool.ComputeRBF(msg.GasPremium, cfg.ReplaceByFeeRatio)
Expand All @@ -935,7 +1027,7 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS
msg.GasPremium = abi.NewTokenAmount(0)
ret, err := s.api.GasEstimateMessageGas(ctx, &msg, mss, types.EmptyTSK)
if err != nil {
return fmt.Errorf("failed to estimate gas values: %w", err)
return nil, fmt.Errorf("failed to estimate gas values: %w", err)
}
msg.GasPremium = big.Max(ret.GasPremium, defaultRBF)
msg.GasFeeCap = big.Max(ret.GasFeeCap, msg.GasPremium)
Expand All @@ -948,10 +1040,10 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS

smsg, err := s.api.WalletSignMessage(ctx, msg.From, &msg)
if err != nil {
return fmt.Errorf("failed to sign message: %w", err)
return nil, fmt.Errorf("failed to sign message: %w", err)
}

return s.db.Transaction(func(tx *gorm.DB) error {
return lo.ToPtr(smsg.Cid()), s.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Delete(&m).Error; err != nil {
return err
}
Expand All @@ -965,6 +1057,7 @@ func (s *Service) replaceMessage(ctx context.Context, id uint, mss *api.MessageS
Extensions: m.Extensions,
RequestID: m.RequestID,
Sectors: m.Sectors,
Nonce: smsg.Message.Nonce,
}

if err := tx.Create(newMsg).Error; err != nil {
Expand Down
Loading