Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: deadlock when signing with the same validator in parallel #108

Merged
merged 10 commits into from
Oct 28, 2024
6 changes: 3 additions & 3 deletions cli/cmd/wallet/cmd/account/handler/handler_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,13 @@ func ValidateHighestValues(accountFlagValues CreateAccountFlagValues) error {
privateKeysCount := len(accountFlagValues.privateKeys)

if len(accountFlagValues.highestSources) != privateKeysCount {
return errors.Errorf("highest sources " + errorExplain)
return errors.Errorf("highest sources %v", errorExplain)
}
if len(accountFlagValues.highestTargets) != privateKeysCount {
return errors.Errorf("highest targets " + errorExplain)
return errors.Errorf("highest targets %v", errorExplain)
}
if len(accountFlagValues.highestProposals) != privateKeysCount {
return errors.Errorf("highest proposals " + errorExplain)
return errors.Errorf("highest proposals %v", errorExplain)
}
} else if accountFlagValues.accumulate {
if len(accountFlagValues.highestSources) != (accountFlagValues.index + 1) {
Expand Down
5 changes: 3 additions & 2 deletions signer/sign_attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ func (signer *SimpleSigner) SignBeaconAttestation(attestation *phase0.Attestatio
}

// 2. lock for current account
signer.lock(account.ID(), "attestation")
val := signer.lock(account.ID(), "attestation")
val.Lock()
defer func() {
signer.unlock(account.ID(), "attestation")
val.Unlock()
}()
y0sher marked this conversation as resolved.
Show resolved Hide resolved

// 3. far future check
Expand Down
51 changes: 51 additions & 0 deletions signer/sign_attestation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,57 @@ func TestReferenceAttestation(t *testing.T) {
require.EqualValues(t, sig, actualSig)
}

// tested against a block and sig generated from https://github.com/prysmaticlabs/prysm/blob/master/shared/testutil/block.go#L357
func TestLockSameValidatorInParallel(t *testing.T) {
sk := _byteArray("2c083f2c8fc923fa2bd32a70ab72b4b46247e8c1f347adc30b2f8036a355086c")
pk := _byteArray("a9cf360aa15fb1d1d30ee2b578dc5884823c19661886ae8b892775ccb3bd96b7d7345569a2aa0b14e4d015c54a6a0c54")
domain := _byteArray32("0100000081509579e35e84020ad8751eca180b44df470332d3ad17fc6fd52459")

store := inmemStorage()
options := &eth2keymanager.KeyVaultOptions{}
options.SetStorage(store)
options.SetWalletType(core.NDWallet)
vault, err := eth2keymanager.NewKeyVault(options)
require.NoError(t, err)
wallet, err := vault.Wallet()
require.NoError(t, err)

k, err := core.NewHDKeyFromPrivateKey(sk, "")
require.NoError(t, err)
acc := wallets.NewValidatorAccount("1", k, nil, "", vault.Context)
require.NoError(t, err)
require.NoError(t, wallet.AddValidatorAccount(acc))

//// setup signer
signer := NewSimpleSigner(wallet, &prot.NoProtection{}, core.MainNetwork)

attestationDataByts := _byteArray("000000000000000000000000000000003a43a4bf26fb5947e809c1f24f7dc6857c8ac007e535d48e6e4eca2122fd776b0000000000000000000000000000000000000000000000000000000000000000000000000000000002000000000000003a43a4bf26fb5947e809c1f24f7dc6857c8ac007e535d48e6e4eca2122fd776b")

// decode attestation
attData := &phase0.AttestationData{}
require.NoError(t, attData.UnmarshalSSZ(attestationDataByts))

go func() {
_, _, err := signer.SignBeaconAttestation(attData, phase0.Domain{0}, pk)
require.NoError(t, err)

}()

ch := make(chan struct{})

go func() {
_, _, err := signer.SignBeaconAttestation(attData, domain, pk)
close(ch)
require.NoError(t, err)
}()

select {
case <-ch:
case <-time.After(200 * time.Millisecond):
t.Fatal("timeout")
}
}

func TestAttestationSlashingSignatures(t *testing.T) {
t.Run("valid attestation, sign using public key", func(t *testing.T) {
seed, _ := hex.DecodeString("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1fff")
Expand Down
5 changes: 3 additions & 2 deletions signer/sign_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ func (signer *SimpleSigner) SignBlock(block ssz.HashRoot, slot phase0.Slot, doma
}

// 2. lock for current account
signer.lock(account.ID(), "proposal")
defer signer.unlock(account.ID(), "proposal")
val := signer.lock(account.ID(), "proposal")
val.Lock()
defer val.Unlock()

// 3. far future check
if !IsValidFarFutureSlot(signer.network, slot) {
Expand Down
15 changes: 9 additions & 6 deletions signer/sign_sync_committee.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ func (signer *SimpleSigner) SignSyncCommittee(msgBlockRoot []byte, domain phase0
}

// 2. lock for current account
signer.lock(account.ID(), "sync_committee")
defer signer.unlock(account.ID(), "sync_committee")
val := signer.lock(account.ID(), "sync_committee")
val.Lock()
defer val.Unlock()

// 3. sign
sszRoot := SSZBytes(msgBlockRoot)
Expand Down Expand Up @@ -51,8 +52,9 @@ func (signer *SimpleSigner) SignSyncCommitteeSelectionData(data *altair.SyncAggr
}

// 2. lock for current account
signer.lock(account.ID(), "sync_committee_selection_data")
defer signer.unlock(account.ID(), "sync_committee_selection_data")
val := signer.lock(account.ID(), "sync_committee_selection_data")
val.Lock()
defer val.Unlock()

// 3. sign
if data == nil {
Expand Down Expand Up @@ -83,8 +85,9 @@ func (signer *SimpleSigner) SignSyncCommitteeContributionAndProof(contribAndProo
}

// 2. lock for current account
signer.lock(account.ID(), "sync_committee_selection_and_proof")
defer signer.unlock(account.ID(), "sync_committee_selection_and_proof")
val := signer.lock(account.ID(), "sync_committee_selection_and_proof")
val.Lock()
defer val.Unlock()

// 3. sign
if contribAndProof == nil {
Expand Down
16 changes: 3 additions & 13 deletions signer/validator_signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,16 @@ func NewSimpleSigner(wallet core.Wallet, slashingProtector core.SlashingProtecto
}

// lock locks signer
func (signer *SimpleSigner) lock(accountID uuid.UUID, operation string) {
func (signer *SimpleSigner) lock(accountID uuid.UUID, operation string) *sync.RWMutex {
signer.mapLock.Lock()
defer signer.mapLock.Unlock()

k := accountID.String() + "_" + operation
if val, ok := signer.signLocks[k]; ok {
val.Lock()
return val
} else {
signer.signLocks[k] = &sync.RWMutex{}
signer.signLocks[k].Lock()
}
}

func (signer *SimpleSigner) unlock(accountID uuid.UUID, operation string) {
signer.mapLock.RLock()
defer signer.mapLock.RUnlock()

k := accountID.String() + "_" + operation
if val, ok := signer.signLocks[k]; ok {
val.Unlock()
return signer.signLocks[k]
}
}

Expand Down
15 changes: 13 additions & 2 deletions wallets/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/hex"
"encoding/json"
"strings"
"sync"

"github.com/google/uuid"
"github.com/pkg/errors"
Expand All @@ -21,6 +22,7 @@ type HDAccount struct {
id uuid.UUID
validationKey *core.HDKey
withdrawalPubKey []byte
contextMtx sync.RWMutex
context *core.WalletContext
}

Expand Down Expand Up @@ -161,7 +163,7 @@ func (account *HDAccount) GetDepositData() (map[string]interface{}, error) {
depositData, root, err := eth1deposit.DepositData(
account.validationKey,
account.withdrawalPubKey,
account.context.Storage.Network(),
account.GetContext().Storage.Network(),
eth1deposit.MaxEffectiveBalanceInGwei,
)
if err != nil {
Expand All @@ -173,11 +175,20 @@ func (account *HDAccount) GetDepositData() (map[string]interface{}, error) {
"signature": strings.TrimPrefix(depositData.Signature.String(), "0x"),
"withdrawalCredentials": hex.EncodeToString(depositData.WithdrawalCredentials),
"depositDataRoot": hex.EncodeToString(root[:]),
"depositContractAddress": account.context.Storage.Network().DepositContractAddress(),
"depositContractAddress": account.GetContext().Storage.Network().DepositContractAddress(),
}, nil
}

// SetContext is the context setter
func (account *HDAccount) SetContext(ctx *core.WalletContext) {
account.contextMtx.Lock()
defer account.contextMtx.Unlock()
account.context = ctx
}

// SetContext is the context setter
olegshmuelov marked this conversation as resolved.
Show resolved Hide resolved
func (account *HDAccount) GetContext() *core.WalletContext {
account.contextMtx.RLock()
defer account.contextMtx.RUnlock()
return account.context
}
Loading