diff --git a/cli/msg.go b/cli/msg.go index 8b575b92..5100ed6a 100644 --- a/cli/msg.go +++ b/cli/msg.go @@ -430,30 +430,31 @@ var updateFilledMessageCmd = &cli.Command{ } defer closer() - var id string - if id = ctx.String("id"); len(id) > 0 { - } else if signedCidStr := ctx.String("signed_cid"); len(signedCidStr) > 0 { - signedCid, err := cid.Decode(signedCidStr) - if err != nil { - return err - } - msg, err := client.GetMessageBySignedCid(ctx.Context, signedCid) - if err != nil { - return err - } - id = msg.ID - } else if unsignedCidStr := ctx.String("unsigned_cid"); len(unsignedCidStr) > 0 { - unsignedCid, err := cid.Decode(unsignedCidStr) - if err != nil { - return err - } - msg, err := client.GetMessageByUnsignedCid(ctx.Context, unsignedCid) - if err != nil { - return err + id := ctx.String("id") + if len(id) == 0 { + if signedCidStr := ctx.String("signed_cid"); len(signedCidStr) > 0 { + signedCid, err := cid.Decode(signedCidStr) + if err != nil { + return err + } + msg, err := client.GetMessageBySignedCid(ctx.Context, signedCid) + if err != nil { + return err + } + id = msg.ID + } else if unsignedCidStr := ctx.String("unsigned_cid"); len(unsignedCidStr) > 0 { + unsignedCid, err := cid.Decode(unsignedCidStr) + if err != nil { + return err + } + msg, err := client.GetMessageByUnsignedCid(ctx.Context, unsignedCid) + if err != nil { + return err + } + id = msg.ID + } else { + return fmt.Errorf("value of query must be entered") } - id = msg.ID - } else { - return fmt.Errorf("value of query must be entered") } _, err = client.UpdateFilledMessageByID(ctx.Context, id) diff --git a/cli/send.go b/cli/send.go index 6d1c7e59..d8fa8549 100644 --- a/cli/send.go +++ b/cli/send.go @@ -61,11 +61,11 @@ var SendCmd = &cli.Command{ return fmt.Errorf("'send' expects two arguments, target and amount") } - client, close, err := getAPI(ctx) + client, closer, err := getAPI(ctx) if err != nil { return err } - defer close() + defer closer() var params types.QuickSendParams diff --git a/gateway/mock_gateway_client.go b/gateway/mock_gateway_client.go index 8387e2b4..580a8715 100644 --- a/gateway/mock_gateway_client.go +++ b/gateway/mock_gateway_client.go @@ -90,7 +90,7 @@ func (m *MockWalletProxy) WalletHas(ctx context.Context, addr address.Address, a return false, nil } -func (m *MockWalletProxy) WalletSign(ctx context.Context, addr address.Address, accounts []string, toSign []byte, meta types.MsgMeta) (*crypto.Signature, error) { +func (m *MockWalletProxy) WalletSign(ctx context.Context, addr address.Address, accounts []string, toSign []byte, _ types.MsgMeta) (*crypto.Signature, error) { has, err := m.WalletHas(ctx, addr, accounts) if err != nil { return nil, err diff --git a/integration_test/message_test.go b/integration_test/message_test.go index 214e9f50..4222be75 100644 --- a/integration_test/message_test.go +++ b/integration_test/message_test.go @@ -37,7 +37,7 @@ func TestHasMessageByUid(t *testing.T) { defer p.closer() t.Run("test has message by uid", func(t *testing.T) { - testHasMessageByUid(p.ctx, t, p.apiAdmin, p.apiSign, p.addrs) + testHasMessageByUid(p.ctx, t, p.apiAdmin, p.addrs) }) assert.NoError(t, p.ms.stop(p.ctx)) } @@ -98,10 +98,10 @@ func TestMessageAPI(t *testing.T) { testGetMessageByFromAndNonce(p.ctx, t, p.apiAdmin, p.apiSign, p.addrs) }) t.Run("test list message by from state", func(t *testing.T) { - testListMessageByFromState(p.ctx, t, p.apiAdmin, p.apiSign, p.addrs) + testListMessageByFromState(p.ctx, t, p.apiAdmin, p.addrs) }) t.Run("test list message by address", func(t *testing.T) { - testListMessageByAddress(p.ctx, t, p.apiAdmin, p.apiSign) + testListMessageByAddress(p.ctx, t, p.apiAdmin) }) t.Run("test list failed message", func(t *testing.T) { testListFailedMessage(p.ctx, t, p.apiAdmin, p.apiSign, p.addrs, p.blockDelay) @@ -113,7 +113,7 @@ func TestMessageAPI(t *testing.T) { testUpdateMessageStateByID(p.ctx, t, p.apiAdmin, p.apiSign, p.addrs, p.blockDelay) }) t.Run("test update all filled message", func(t *testing.T) { - testUpdateAllFilledMessage(p.ctx, t, p.apiAdmin, p.apiSign, p.addrs, p.blockDelay) + testUpdateAllFilledMessage(p.ctx, t, p.apiAdmin, p.addrs, p.blockDelay) }) t.Run("test replace message", func(t *testing.T) { testReplaceMessage(p.ctx, t, p.apiAdmin, p.apiSign, p.addrs, p.blockDelay) @@ -164,7 +164,7 @@ func testPushMessageWithID(ctx context.Context, t *testing.T, api, apiSign messa } } -func testHasMessageByUid(ctx context.Context, t *testing.T, api, apiSign messager.IMessager, addrs []address.Address) { +func testHasMessageByUid(ctx context.Context, t *testing.T, api messager.IMessager, addrs []address.Address) { msgs := genMessageWithAddress(addrs) for _, msg := range msgs { id, err := api.PushMessageWithId(ctx, msg.ID, &msg.Message, nil) @@ -330,7 +330,7 @@ func testListMessage(ctx context.Context, t *testing.T, api, apiSign messager.IM } } -func testListMessageByFromState(ctx context.Context, t *testing.T, api, apiSign messager.IMessager, addrs []address.Address) { +func testListMessageByFromState(ctx context.Context, t *testing.T, api messager.IMessager, addrs []address.Address) { // insert message genMessagesAndWait(ctx, t, api, addrs) genMessagesAndWait(ctx, t, api, addrs) @@ -403,7 +403,7 @@ func testListMessageByFromState(ctx context.Context, t *testing.T, api, apiSign } } -func testListMessageByAddress(ctx context.Context, t *testing.T, api, apiSign messager.IMessager) { +func testListMessageByAddress(ctx context.Context, t *testing.T, api messager.IMessager) { allMsgs, err := api.ListMessage(ctx, &types.MsgQueryParams{}) assert.NoError(t, err) msgIDs := make(map[address.Address][]string) @@ -552,7 +552,7 @@ func testUpdateMessageStateByID(ctx context.Context, t *testing.T, api, apiSign } } -func testUpdateAllFilledMessage(ctx context.Context, t *testing.T, api, apiSign messager.IMessager, addrs []address.Address, blockDelay time.Duration) { +func testUpdateAllFilledMessage(ctx context.Context, t *testing.T, api messager.IMessager, addrs []address.Address, blockDelay time.Duration) { msgs := genMessageWithAddress(addrs) for _, msg := range msgs { id, err := api.PushMessageWithId(ctx, msg.ID, &msg.Message, nil) @@ -805,8 +805,7 @@ func checkUnsignedMsg(t *testing.T, expect, actual *shared.Message) { assert.Equal(t, expect.Params, actual.Params) assert.Equal(t, testhelper.ResolveAddr(t, expect.From), actual.From) // todo: finish estimate gas - if actual.Nonce > 0 { - } else { + if actual.Nonce == 0 { assert.Equal(t, expect.GasLimit, actual.GasLimit) assert.Equal(t, expect.GasFeeCap, actual.GasFeeCap) assert.Equal(t, expect.GasPremium, actual.GasPremium) diff --git a/models/mysql/actor_cfg.go b/models/mysql/actor_cfg.go index c895a68d..252a7fa3 100644 --- a/models/mysql/actor_cfg.go +++ b/models/mysql/actor_cfg.go @@ -95,7 +95,7 @@ func (s *mysqlActorCfgRepo) SaveActorCfg(ctx context.Context, actorCfg *types.Ac func (s *mysqlActorCfgRepo) HasActorCfg(ctx context.Context, methodType *types.MethodType) (bool, error) { var count int64 - if err := s.DB.Table("actor_cfg").Where("code = ? and method = ?", mtypes.NewDBCid(methodType.Code), + if err := s.DB.WithContext(ctx).Table("actor_cfg").Where("code = ? and method = ?", mtypes.NewDBCid(methodType.Code), methodType.Method).Count(&count).Error; err != nil { return false, err } @@ -105,7 +105,7 @@ func (s *mysqlActorCfgRepo) HasActorCfg(ctx context.Context, methodType *types.M func (s *mysqlActorCfgRepo) GetActorCfgByMethodType(ctx context.Context, methodType *types.MethodType) (*types.ActorCfg, error) { var a mysqlActorCfg - if err := s.DB.Take(&a, "code = ? and method = ?", mtypes.NewDBCid(methodType.Code), methodType.Method).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&a, "code = ? and method = ?", mtypes.NewDBCid(methodType.Code), methodType.Method).Error; err != nil { return nil, err } @@ -114,7 +114,7 @@ func (s *mysqlActorCfgRepo) GetActorCfgByMethodType(ctx context.Context, methodT func (s *mysqlActorCfgRepo) GetActorCfgByID(ctx context.Context, id shared.UUID) (*types.ActorCfg, error) { var a mysqlActorCfg - if err := s.DB.Take(&a, "id = ?", id).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&a, "id = ?", id).Error; err != nil { return nil, err } @@ -123,7 +123,7 @@ func (s *mysqlActorCfgRepo) GetActorCfgByID(ctx context.Context, id shared.UUID) func (s *mysqlActorCfgRepo) ListActorCfg(ctx context.Context) ([]*types.ActorCfg, error) { var list []*mysqlActorCfg - if err := s.DB.Find(&list).Error; err != nil { + if err := s.DB.WithContext(ctx).Find(&list).Error; err != nil { return nil, err } @@ -136,11 +136,11 @@ func (s *mysqlActorCfgRepo) ListActorCfg(ctx context.Context) ([]*types.ActorCfg } func (s *mysqlActorCfgRepo) DelActorCfgByMethodType(ctx context.Context, methodType *types.MethodType) error { - return s.DB.Delete(mysqlActorCfg{}, "code = ? and method = ?", mtypes.NewDBCid(methodType.Code), methodType.Method).Error + return s.DB.WithContext(ctx).Delete(mysqlActorCfg{}, "code = ? and method = ?", mtypes.NewDBCid(methodType.Code), methodType.Method).Error } func (s *mysqlActorCfgRepo) DelActorCfgById(ctx context.Context, id shared.UUID) error { - return s.DB.Delete(mysqlActorCfg{}, "id = ?", id).Error + return s.DB.WithContext(ctx).Delete(mysqlActorCfg{}, "id = ?", id).Error } func (s *mysqlActorCfgRepo) UpdateSelectSpecById(ctx context.Context, id shared.UUID, spec *types.ChangeGasSpecParams) error { @@ -168,5 +168,5 @@ func (s *mysqlActorCfgRepo) UpdateSelectSpecById(ctx context.Context, id shared. updateColumns["updated_at"] = time.Now() - return s.DB.Model((*mysqlActorCfg)(nil)).Where("id = ?", id).UpdateColumns(updateColumns).Error + return s.DB.WithContext(ctx).Model((*mysqlActorCfg)(nil)).Where("id = ?", id).UpdateColumns(updateColumns).Error } diff --git a/models/mysql/address.go b/models/mysql/address.go index cac98215..61f4b870 100644 --- a/models/mysql/address.go +++ b/models/mysql/address.go @@ -90,12 +90,12 @@ func newMysqlAddressRepo(db *gorm.DB) *mysqlAddressRepo { } func (s mysqlAddressRepo) SaveAddress(ctx context.Context, a *types.Address) error { - return s.DB.Save(fromAddress(a)).Error + return s.DB.WithContext(ctx).Save(fromAddress(a)).Error } func (s mysqlAddressRepo) GetAddress(ctx context.Context, addr address.Address) (*types.Address, error) { var a mysqlAddress - if err := s.DB.Take(&a, "addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&a, "addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted).Error; err != nil { return nil, err } @@ -104,7 +104,7 @@ func (s mysqlAddressRepo) GetAddress(ctx context.Context, addr address.Address) func (s mysqlAddressRepo) GetAddressByID(ctx context.Context, id shared.UUID) (*types.Address, error) { var a mysqlAddress - if err := s.DB.Where("id = ? and is_deleted = ?", id, repo.NotDeleted).First(&a).Error; err != nil { + if err := s.DB.WithContext(ctx).Where("id = ? and is_deleted = ?", id, repo.NotDeleted).First(&a).Error; err != nil { return nil, err } @@ -113,7 +113,7 @@ func (s mysqlAddressRepo) GetAddressByID(ctx context.Context, id shared.UUID) (* func (s mysqlAddressRepo) GetOneRecord(ctx context.Context, addr address.Address) (*types.Address, error) { var a mysqlAddress - if err := s.DB.Take(&a, "addr = ?", addr.String()).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&a, "addr = ?", addr.String()).Error; err != nil { return nil, err } @@ -122,7 +122,7 @@ func (s mysqlAddressRepo) GetOneRecord(ctx context.Context, addr address.Address func (s mysqlAddressRepo) HasAddress(ctx context.Context, addr address.Address) (bool, error) { var count int64 - if err := s.DB.Model(&mysqlAddress{}).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). + if err := s.DB.WithContext(ctx).Model(&mysqlAddress{}).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). Count(&count).Error; err != nil { return false, err } @@ -131,7 +131,7 @@ func (s mysqlAddressRepo) HasAddress(ctx context.Context, addr address.Address) func (s mysqlAddressRepo) ListAddress(ctx context.Context) ([]*types.Address, error) { var list []*mysqlAddress - if err := s.DB.Find(&list, "is_deleted = ?", repo.NotDeleted).Error; err != nil { + if err := s.DB.WithContext(ctx).Find(&list, "is_deleted = ?", repo.NotDeleted).Error; err != nil { return nil, err } @@ -149,7 +149,7 @@ func (s mysqlAddressRepo) ListAddress(ctx context.Context) ([]*types.Address, er func (s mysqlAddressRepo) ListActiveAddress(ctx context.Context) ([]*types.Address, error) { var list []*mysqlAddress - if err := s.DB.Find(&list, "is_deleted = ? and state = ?", repo.NotDeleted, types.AddressStateAlive).Error; err != nil { + if err := s.DB.WithContext(ctx).Find(&list, "is_deleted = ? and state = ?", repo.NotDeleted, types.AddressStateAlive).Error; err != nil { return nil, err } @@ -166,22 +166,22 @@ func (s mysqlAddressRepo) ListActiveAddress(ctx context.Context) ([]*types.Addre } func (s mysqlAddressRepo) DelAddress(ctx context.Context, addr address.Address) error { - return s.DB.Model((*mysqlAddress)(nil)).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). + return s.DB.WithContext(ctx).Model((*mysqlAddress)(nil)).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). UpdateColumns(map[string]interface{}{"is_deleted": repo.Deleted, "state": types.AddressStateRemoved, "updated_at": time.Now()}).Error } func (s mysqlAddressRepo) UpdateNonce(ctx context.Context, addr address.Address, nonce uint64) error { - return s.DB.Model(&mysqlAddress{}).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). + return s.DB.WithContext(ctx).Model(&mysqlAddress{}).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). UpdateColumns(map[string]interface{}{"nonce": nonce, "updated_at": time.Now()}).Error } func (s mysqlAddressRepo) UpdateState(ctx context.Context, addr address.Address, state types.AddressState) error { - return s.DB.Model(&mysqlAddress{}).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). + return s.DB.WithContext(ctx).Model(&mysqlAddress{}).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). UpdateColumns(map[string]interface{}{"state": state, "updated_at": time.Now()}).Error } func (s mysqlAddressRepo) UpdateSelectMsgNum(ctx context.Context, addr address.Address, num uint64) error { - return s.DB.Model((*mysqlAddress)(nil)).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). + return s.DB.WithContext(ctx).Model((*mysqlAddress)(nil)).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted). UpdateColumns(map[string]interface{}{"sel_msg_num": num, "updated_at": time.Now()}).Error } @@ -208,5 +208,5 @@ func (s mysqlAddressRepo) UpdateFeeParams(ctx context.Context, addr address.Addr updateColumns["updated_at"] = time.Now() - return s.DB.Model((*mysqlAddress)(nil)).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted).UpdateColumns(updateColumns).Error + return s.DB.WithContext(ctx).Model((*mysqlAddress)(nil)).Where("addr = ? and is_deleted = ?", addr.String(), repo.NotDeleted).UpdateColumns(updateColumns).Error } diff --git a/models/mysql/shared_params.go b/models/mysql/shared_params.go index 1989687d..11fbb8b0 100644 --- a/models/mysql/shared_params.go +++ b/models/mysql/shared_params.go @@ -62,7 +62,7 @@ func newMysqlSharedParamsRepo(db *gorm.DB) mysqlSharedParamsRepo { func (s mysqlSharedParamsRepo) GetSharedParams(ctx context.Context) (*types.SharedSpec, error) { var ssp mysqlSharedParams - if err := s.DB.Take(&ssp).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&ssp).Error; err != nil { return nil, err } return ssp.SharedParams(), nil @@ -72,9 +72,9 @@ func (s mysqlSharedParamsRepo) SetSharedParams(ctx context.Context, params *type var ssp mysqlSharedParams // make sure ID is 1 params.ID = 1 - if err := s.DB.Where("id = ?", 1).Take(&ssp).Error; err != nil { + if err := s.DB.WithContext(ctx).Where("id = ?", 1).Take(&ssp).Error; err != nil { if err == gorm.ErrRecordNotFound { - if err := s.DB.Save(fromSharedParams(*params)).Error; err != nil { + if err := s.DB.WithContext(ctx).Save(fromSharedParams(*params)).Error; err != nil { return 0, err } return params.ID, nil @@ -82,7 +82,7 @@ func (s mysqlSharedParamsRepo) SetSharedParams(ctx context.Context, params *type return 0, err } - if err := s.DB.Save(fromSharedParams(*params)).Error; err != nil { + if err := s.DB.WithContext(ctx).Save(fromSharedParams(*params)).Error; err != nil { return 0, err } diff --git a/models/sqlite/actor_cfg.go b/models/sqlite/actor_cfg.go index 35972abc..68644a4d 100644 --- a/models/sqlite/actor_cfg.go +++ b/models/sqlite/actor_cfg.go @@ -89,12 +89,12 @@ func (s *sqliteActorCfgRepo) SaveActorCfg(ctx context.Context, actorCfg *types.A if actorCfg.Code == cid.Undef { return errors.New("code cid is undefined") } - return s.DB.Save(fromActorCfg(actorCfg)).Error + return s.DB.WithContext(ctx).Save(fromActorCfg(actorCfg)).Error } func (s *sqliteActorCfgRepo) HasActorCfg(ctx context.Context, methodType *types.MethodType) (bool, error) { var count int64 - if err := s.DB.Table("actor_cfg").Where("code = ? and method = ?", mtypes.NewDBCid(methodType.Code), + if err := s.DB.WithContext(ctx).Table("actor_cfg").Where("code = ? and method = ?", mtypes.NewDBCid(methodType.Code), methodType.Method).Count(&count).Error; err != nil { return false, err } @@ -104,7 +104,7 @@ func (s *sqliteActorCfgRepo) HasActorCfg(ctx context.Context, methodType *types. func (s *sqliteActorCfgRepo) GetActorCfgByMethodType(ctx context.Context, methodType *types.MethodType) (*types.ActorCfg, error) { var a sqliteActorCfg - if err := s.DB.Take(&a, "code = ? and method = ?", mtypes.DBCid(methodType.Code), sqliteUint64(methodType.Method)).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&a, "code = ? and method = ?", mtypes.DBCid(methodType.Code), sqliteUint64(methodType.Method)).Error; err != nil { return nil, err } @@ -113,7 +113,7 @@ func (s *sqliteActorCfgRepo) GetActorCfgByMethodType(ctx context.Context, method func (s *sqliteActorCfgRepo) GetActorCfgByID(ctx context.Context, id shared.UUID) (*types.ActorCfg, error) { var a sqliteActorCfg - if err := s.DB.Take(&a, "id = ?", id).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&a, "id = ?", id).Error; err != nil { return nil, err } @@ -122,7 +122,7 @@ func (s *sqliteActorCfgRepo) GetActorCfgByID(ctx context.Context, id shared.UUID func (s *sqliteActorCfgRepo) ListActorCfg(ctx context.Context) ([]*types.ActorCfg, error) { var list []*sqliteActorCfg - if err := s.DB.Find(&list).Error; err != nil { + if err := s.DB.WithContext(ctx).Find(&list).Error; err != nil { return nil, err } @@ -135,11 +135,11 @@ func (s *sqliteActorCfgRepo) ListActorCfg(ctx context.Context) ([]*types.ActorCf } func (s *sqliteActorCfgRepo) DelActorCfgByMethodType(ctx context.Context, methodType *types.MethodType) error { - return s.DB.Delete(sqliteActorCfg{}, "code = ? and method = ?", mtypes.DBCid(methodType.Code), sqliteUint64(methodType.Method)).Error + return s.DB.WithContext(ctx).Delete(sqliteActorCfg{}, "code = ? and method = ?", mtypes.DBCid(methodType.Code), sqliteUint64(methodType.Method)).Error } func (s *sqliteActorCfgRepo) DelActorCfgById(ctx context.Context, id shared.UUID) error { - return s.DB.Delete(sqliteActorCfg{}, "id = ?", id).Error + return s.DB.WithContext(ctx).Delete(sqliteActorCfg{}, "id = ?", id).Error } func (s *sqliteActorCfgRepo) UpdateSelectSpecById(ctx context.Context, id shared.UUID, spec *types.ChangeGasSpecParams) error { @@ -167,5 +167,5 @@ func (s *sqliteActorCfgRepo) UpdateSelectSpecById(ctx context.Context, id shared updateColumns["updated_at"] = time.Now() - return s.DB.Model((*sqliteActorCfg)(nil)).Where("id = ?", id).UpdateColumns(updateColumns).Error + return s.DB.WithContext(ctx).Model((*sqliteActorCfg)(nil)).Where("id = ?", id).UpdateColumns(updateColumns).Error } diff --git a/models/sqlite/address.go b/models/sqlite/address.go index 4aa5991c..b05b8840 100644 --- a/models/sqlite/address.go +++ b/models/sqlite/address.go @@ -89,12 +89,12 @@ func newSqliteAddressRepo(db *gorm.DB) *sqliteAddressRepo { } func (s sqliteAddressRepo) SaveAddress(ctx context.Context, addr *types.Address) error { - return s.DB.Save(fromAddress(addr)).Error + return s.DB.WithContext(ctx).Save(fromAddress(addr)).Error } func (s sqliteAddressRepo) GetAddress(ctx context.Context, addr address.Address) (*types.Address, error) { var a sqliteAddress - if err := s.DB.Take(&a, "addr = ? and is_deleted = -1", addr.String()).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&a, "addr = ? and is_deleted = -1", addr.String()).Error; err != nil { return nil, err } @@ -103,7 +103,7 @@ func (s sqliteAddressRepo) GetAddress(ctx context.Context, addr address.Address) func (s sqliteAddressRepo) GetAddressByID(ctx context.Context, id shared.UUID) (*types.Address, error) { var a sqliteAddress - if err := s.DB.Where("id = ? and is_deleted = -1", id).First(&a).Error; err != nil { + if err := s.DB.WithContext(ctx).Where("id = ? and is_deleted = -1", id).First(&a).Error; err != nil { return nil, err } @@ -112,7 +112,7 @@ func (s sqliteAddressRepo) GetAddressByID(ctx context.Context, id shared.UUID) ( func (s sqliteAddressRepo) GetOneRecord(ctx context.Context, addr address.Address) (*types.Address, error) { var a sqliteAddress - if err := s.DB.Take(&a, "addr = ?", addr.String()).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&a, "addr = ?", addr.String()).Error; err != nil { return nil, err } @@ -121,7 +121,7 @@ func (s sqliteAddressRepo) GetOneRecord(ctx context.Context, addr address.Addres func (s sqliteAddressRepo) HasAddress(ctx context.Context, addr address.Address) (bool, error) { var count int64 - if err := s.DB.Model((*sqliteAddress)(nil)). + if err := s.DB.WithContext(ctx).Model((*sqliteAddress)(nil)). Where("addr = ? and is_deleted = -1", addr.String()).Count(&count).Error; err != nil { return false, err } @@ -130,7 +130,7 @@ func (s sqliteAddressRepo) HasAddress(ctx context.Context, addr address.Address) func (s sqliteAddressRepo) ListAddress(ctx context.Context) ([]*types.Address, error) { var list []*sqliteAddress - if err := s.DB.Find(&list, "is_deleted = ?", -1).Error; err != nil { + if err := s.DB.WithContext(ctx).Find(&list, "is_deleted = ?", -1).Error; err != nil { return nil, err } @@ -148,7 +148,7 @@ func (s sqliteAddressRepo) ListAddress(ctx context.Context) ([]*types.Address, e func (s sqliteAddressRepo) ListActiveAddress(ctx context.Context) ([]*types.Address, error) { var list []*sqliteAddress - if err := s.DB.Find(&list, "is_deleted = ? and state = ?", -1, types.AddressStateAlive).Error; err != nil { + if err := s.DB.WithContext(ctx).Find(&list, "is_deleted = ? and state = ?", -1, types.AddressStateAlive).Error; err != nil { return nil, err } @@ -165,17 +165,17 @@ func (s sqliteAddressRepo) ListActiveAddress(ctx context.Context) ([]*types.Addr } func (s sqliteAddressRepo) UpdateNonce(ctx context.Context, addr address.Address, nonce uint64) error { - return s.DB.Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()). + return s.DB.WithContext(ctx).Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()). UpdateColumns(map[string]interface{}{"nonce": nonce, "updated_at": time.Now()}).Error } func (s sqliteAddressRepo) UpdateState(ctx context.Context, addr address.Address, state types.AddressState) error { - return s.DB.Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()). + return s.DB.WithContext(ctx).Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()). UpdateColumns(map[string]interface{}{"state": state, "updated_at": time.Now()}).Error } func (s sqliteAddressRepo) UpdateSelectMsgNum(ctx context.Context, addr address.Address, num uint64) error { - return s.DB.Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()). + return s.DB.WithContext(ctx).Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()). UpdateColumns(map[string]interface{}{"sel_msg_num": num, "updated_at": time.Now()}).Error } @@ -202,11 +202,11 @@ func (s sqliteAddressRepo) UpdateFeeParams(ctx context.Context, addr address.Add updateColumns["updated_at"] = time.Now() - return s.DB.Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()).UpdateColumns(updateColumns).Error + return s.DB.WithContext(ctx).Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()).UpdateColumns(updateColumns).Error } func (s sqliteAddressRepo) DelAddress(ctx context.Context, addr address.Address) error { - return s.DB.Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()). + return s.DB.WithContext(ctx).Model((*sqliteAddress)(nil)).Where("addr = ? and is_deleted = -1", addr.String()). UpdateColumns(map[string]interface{}{"is_deleted": repo.Deleted, "state": types.AddressStateRemoved, "updated_at": time.Now()}).Error } diff --git a/models/sqlite/shared_params.go b/models/sqlite/shared_params.go index cdbba92a..a4fe54ab 100644 --- a/models/sqlite/shared_params.go +++ b/models/sqlite/shared_params.go @@ -62,7 +62,7 @@ func newSqliteSharedParamsRepo(db *gorm.DB) sqliteSharedParamsRepo { func (s sqliteSharedParamsRepo) GetSharedParams(ctx context.Context) (*types.SharedSpec, error) { var ssp sqliteSharedParams - if err := s.DB.Take(&ssp).Error; err != nil { + if err := s.DB.WithContext(ctx).Take(&ssp).Error; err != nil { return nil, err } return ssp.SharedParams(), nil @@ -72,9 +72,9 @@ func (s sqliteSharedParamsRepo) SetSharedParams(ctx context.Context, params *typ var ssp sqliteSharedParams // make sure ID is 1 params.ID = 1 - if err := s.DB.Where("id = ?", 1).Take(&ssp).Error; err != nil { + if err := s.DB.WithContext(ctx).Where("id = ?", 1).Take(&ssp).Error; err != nil { if err == gorm.ErrRecordNotFound { - if err := s.DB.Save(fromSharedParams(*params)).Error; err != nil { + if err := s.DB.WithContext(ctx).Save(fromSharedParams(*params)).Error; err != nil { return 0, err } return params.ID, nil @@ -82,7 +82,7 @@ func (s sqliteSharedParamsRepo) SetSharedParams(ctx context.Context, params *typ return 0, err } - if err := s.DB.Save(fromSharedParams(*params)).Error; err != nil { + if err := s.DB.WithContext(ctx).Save(fromSharedParams(*params)).Error; err != nil { return 0, err } diff --git a/service/message_service_test.go b/service/message_service_test.go index df19fa62..b58bff71 100644 --- a/service/message_service_test.go +++ b/service/message_service_test.go @@ -443,7 +443,7 @@ func TestMessageService_ProcessNewHead(t *testing.T) { full := v1Mock.NewMockFullNode(gomock.NewController(t)) ms.nodeClient = full - full.EXPECT().ChainGetTipSet(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(func(arg0 context.Context, arg1 shared.TipSetKey) (*shared.TipSet, error) { + full.EXPECT().ChainGetTipSet(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(func(_ context.Context, arg1 shared.TipSetKey) (*shared.TipSet, error) { for _, ts := range tipSets { if ts.Key().Equals(arg1) { return ts, nil diff --git a/testhelper/mock_full_node.go b/testhelper/mock_full_node.go index 3d689b10..46009941 100644 --- a/testhelper/mock_full_node.go +++ b/testhelper/mock_full_node.go @@ -357,28 +357,28 @@ func (f *MockFullNode) blockProvider() (*types.BlockHeader, error) { //// full api //// -func (f *MockFullNode) StateAccountKey(ctx context.Context, addr address.Address, tsk types.TipSetKey) (address.Address, error) { +func (f *MockFullNode) StateAccountKey(_ context.Context, addr address.Address, tsk types.TipSetKey) (address.Address, error) { if addr.Protocol() != address.ID { return addr, nil } return ResolveIDAddr(addr) } -func (f *MockFullNode) StateNetworkName(ctx context.Context) (types.NetworkName, error) { +func (f *MockFullNode) StateNetworkName(_ context.Context) (types.NetworkName, error) { return types.NetworkNameMain, nil } -func (f *MockFullNode) StateNetworkVersion(arg0 context.Context, arg1 types.TipSetKey) (network.Version, error) { +func (f *MockFullNode) StateNetworkVersion(_ context.Context, _ types.TipSetKey) (network.Version, error) { return network.Version17, nil } -func (f *MockFullNode) StateGetNetworkParams(ctx context.Context) (*types.NetworkParams, error) { +func (f *MockFullNode) StateGetNetworkParams(_ context.Context) (*types.NetworkParams, error) { return &types.NetworkParams{ NetworkName: types.NetworkNameMain, BlockDelaySecs: uint64(f.blockDelay / time.Second), }, nil } -func (f *MockFullNode) ChainGetParentMessages(ctx context.Context, bcid cid.Cid) ([]types.MessageCID, error) { +func (f *MockFullNode) ChainGetParentMessages(_ context.Context, bcid cid.Cid) ([]types.MessageCID, error) { f.l.Lock() defer f.l.Unlock() blkInfo, ok := f.blockInfos[bcid] @@ -400,7 +400,7 @@ func (f *MockFullNode) ChainGetParentMessages(ctx context.Context, bcid cid.Cid) return msgCid, nil } -func (f *MockFullNode) ChainGetParentReceipts(ctx context.Context, bcid cid.Cid) ([]*types.MessageReceipt, error) { +func (f *MockFullNode) ChainGetParentReceipts(_ context.Context, bcid cid.Cid) ([]*types.MessageReceipt, error) { f.l.Lock() defer f.l.Unlock() blkInfo, ok := f.blockInfos[bcid] @@ -419,7 +419,7 @@ func (f *MockFullNode) ChainGetParentReceipts(ctx context.Context, bcid cid.Cid) return receipts, nil } -func (f *MockFullNode) ChainGetTipSet(ctx context.Context, key types.TipSetKey) (*types.TipSet, error) { +func (f *MockFullNode) ChainGetTipSet(_ context.Context, key types.TipSetKey) (*types.TipSet, error) { f.l.Lock() defer f.l.Unlock() @@ -456,7 +456,7 @@ func (f *MockFullNode) ChainList(ctx context.Context, tsKey types.TipSetKey, cou return keys, nil } -func (f *MockFullNode) ChainGetMessagesInTipset(ctx context.Context, key types.TipSetKey) ([]types.MessageCID, error) { +func (f *MockFullNode) ChainGetMessagesInTipset(_ context.Context, key types.TipSetKey) ([]types.MessageCID, error) { f.l.Lock() defer f.l.Unlock() _, ok := f.ts[key] @@ -479,14 +479,14 @@ func (f *MockFullNode) ChainGetMessagesInTipset(ctx context.Context, key types.T return msgs, nil } -func (f *MockFullNode) ChainHead(ctx context.Context) (*types.TipSet, error) { +func (f *MockFullNode) ChainHead(_ context.Context) (*types.TipSet, error) { f.l.Lock() defer f.l.Unlock() return f.currTS, nil } -func (f *MockFullNode) StateGetActor(ctx context.Context, addr address.Address, tsk types.TipSetKey) (*types.Actor, error) { +func (f *MockFullNode) StateGetActor(_ context.Context, addr address.Address, tsk types.TipSetKey) (*types.Actor, error) { f.l.Lock() defer f.l.Unlock() @@ -528,7 +528,7 @@ func (f *MockFullNode) GasBatchEstimateMessageGas(ctx context.Context, estimateM return res, nil } -func (f *MockFullNode) GasEstimateMessageGas(ctx context.Context, msg *types.Message, spec *types.MessageSendSpec, tsk types.TipSetKey) (*types.Message, error) { +func (f *MockFullNode) GasEstimateMessageGas(_ context.Context, msg *types.Message, spec *types.MessageSendSpec, tsk types.TipSetKey) (*types.Message, error) { err := estimateGasLimit(msg, spec) if err != nil { return nil, err @@ -577,7 +577,7 @@ func estimateGasLimit(msg *types.Message, spec *types.MessageSendSpec) error { return nil } -func (f *MockFullNode) MpoolBatchPushUntrusted(ctx context.Context, smsgs []*types.SignedMessage) ([]cid.Cid, error) { +func (f *MockFullNode) MpoolBatchPushUntrusted(_ context.Context, smsgs []*types.SignedMessage) ([]cid.Cid, error) { f.l.Lock() defer f.l.Unlock() cids := make([]cid.Cid, 0, len(smsgs)) @@ -605,7 +605,7 @@ func (f *MockFullNode) MpoolGetConfig(_ context.Context) (*types.MpoolConfig, er }, nil } -func (f *MockFullNode) StateSearchMsg(ctx context.Context, from types.TipSetKey, msgCid cid.Cid, limit abi.ChainEpoch, allowReplaced bool) (*types.MsgLookup, error) { +func (f *MockFullNode) StateSearchMsg(_ context.Context, from types.TipSetKey, msgCid cid.Cid, limit abi.ChainEpoch, allowReplaced bool) (*types.MsgLookup, error) { f.l.Lock() defer f.l.Unlock() _, ok := f.chainMsgs[msgCid]