diff --git a/http/validators.go b/http/validators.go index 8504c7c4..ee80c5ce 100644 --- a/http/validators.go +++ b/http/validators.go @@ -70,7 +70,8 @@ func (s *Service) indexChunkSize(ctx context.Context) int { // Validators provides the validators, with their balance and status, for a given state. // stateID can be a slot number or state root, or one of the special values "genesis", "head", "justified" or "finalized". // validatorIndices is a list of validators to restrict the returned values. If no validators are supplied no filter will be applied. -func (s *Service) Validators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex) (map[phase0.ValidatorIndex]*api.Validator, error) { +// validatorStates is a list of validator states to restrict the returned values. If no states are supplied no filter will be applied. +func (s *Service) Validators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex, validatorStates []api.ValidatorState) (map[phase0.ValidatorIndex]*api.Validator, error) { if stateID == "" { return nil, errors.New("no state ID specified") } @@ -80,7 +81,7 @@ func (s *Service) Validators(ctx context.Context, stateID string, validatorIndic } if len(validatorIndices) > s.indexChunkSize(ctx) { - return s.chunkedValidators(ctx, stateID, validatorIndices) + return s.chunkedValidators(ctx, stateID, validatorIndices, validatorStates) } url := fmt.Sprintf("/eth/v1/beacon/states/%s/validators", stateID) @@ -91,6 +92,17 @@ func (s *Service) Validators(ctx context.Context, stateID string, validatorIndic } url = fmt.Sprintf("%s?id=%s", url, strings.Join(ids, ",")) } + if len(validatorStates) != 0 { + states := make([]string, len(validatorStates)) + for i := range validatorStates { + states[i] = validatorStates[i].String() + } + if len(validatorIndices) != 0 { + url = fmt.Sprintf("%s&status=%s", url, strings.Join(states, ",")) + } else { + url = fmt.Sprintf("%s?status=%s", url, strings.Join(states, ",")) + } + } respBodyReader, err := s.get(ctx, url) if err != nil { @@ -168,7 +180,7 @@ func (s *Service) validatorsFromState(ctx context.Context, stateID string) (map[ } // chunkedValidators obtains the validators a chunk at a time. -func (s *Service) chunkedValidators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex) (map[phase0.ValidatorIndex]*api.Validator, error) { +func (s *Service) chunkedValidators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex, validatorStates []api.ValidatorState) (map[phase0.ValidatorIndex]*api.Validator, error) { res := make(map[phase0.ValidatorIndex]*api.Validator) indexChunkSize := s.indexChunkSize(ctx) for i := 0; i < len(validatorIndices); i += indexChunkSize { @@ -178,7 +190,7 @@ func (s *Service) chunkedValidators(ctx context.Context, stateID string, validat chunkEnd = len(validatorIndices) } chunk := validatorIndices[chunkStart:chunkEnd] - chunkRes, err := s.Validators(ctx, stateID, chunk) + chunkRes, err := s.Validators(ctx, stateID, chunk, validatorStates) if err != nil { return nil, errors.Wrap(err, "failed to obtain chunk") } diff --git a/http/validators_test.go b/http/validators_test.go index 3a62b319..57041d72 100644 --- a/http/validators_test.go +++ b/http/validators_test.go @@ -20,6 +20,7 @@ import ( "testing" client "github.com/attestantio/go-eth2-client" + v1 "github.com/attestantio/go-eth2-client/api/v1" "github.com/attestantio/go-eth2-client/http" "github.com/attestantio/go-eth2-client/spec/phase0" "github.com/stretchr/testify/require" @@ -34,6 +35,7 @@ func TestValidators(t *testing.T) { stateID string expectedErrorCode int validatorIndices []phase0.ValidatorIndex + validatorStates []v1.ValidatorState }{ { name: "Invalid", @@ -61,6 +63,14 @@ func TestValidators(t *testing.T) { name: "Justified", stateID: "justified", }, + { + name: "SomeStates", + stateID: "head", + validatorStates: []v1.ValidatorState{ + v1.ValidatorStateActiveOngoing, + v1.ValidatorStateExitedSlashed, + }, + }, { name: "ManyValidators", stateID: "head", @@ -78,7 +88,7 @@ func TestValidators(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - validators, err := service.(client.ValidatorsProvider).Validators(ctx, test.stateID, test.validatorIndices) + validators, err := service.(client.ValidatorsProvider).Validators(ctx, test.stateID, test.validatorIndices, test.validatorStates) if test.expectedErrorCode != 0 { require.Contains(t, err.Error(), fmt.Sprintf("%d", test.expectedErrorCode)) } else { diff --git a/mock/validators.go b/mock/validators.go index 4c9ac1d5..d9da02f3 100644 --- a/mock/validators.go +++ b/mock/validators.go @@ -24,6 +24,6 @@ import ( // stateID can be a slot number or state root, or one of the special values "genesis", "head", "justified" or "finalized". // validatorIndices is a list of validator indices to restrict the returned values. If no validators IDs are supplied no filter // will be applied. -func (s *Service) Validators(_ context.Context, _ string, _ []phase0.ValidatorIndex) (map[phase0.ValidatorIndex]*api.Validator, error) { +func (s *Service) Validators(_ context.Context, _ string, _ []phase0.ValidatorIndex, _ []api.ValidatorState) (map[phase0.ValidatorIndex]*api.Validator, error) { return map[phase0.ValidatorIndex]*api.Validator{}, nil } diff --git a/multi/validators.go b/multi/validators.go index 730d435b..b4c5dc58 100644 --- a/multi/validators.go +++ b/multi/validators.go @@ -24,15 +24,18 @@ import ( // Validators provides the validators, with their balance and status, for a given state. // stateID can be a slot number or state root, or one of the special values "genesis", "head", "justified" or "finalized". // validatorIndices is a list of validators to restrict the returned values. If no validators are supplied no filter will be applied. +// validatorStates is a list of validator states to restrict the returned values. If no validators states are supplied no filter +// will be applied. func (s *Service) Validators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex, + validatorStates []api.ValidatorState, ) ( map[phase0.ValidatorIndex]*api.Validator, error, ) { res, err := s.doCall(ctx, func(ctx context.Context, client consensusclient.Service) (interface{}, error) { - block, err := client.(consensusclient.ValidatorsProvider).Validators(ctx, stateID, validatorIndices) + block, err := client.(consensusclient.ValidatorsProvider).Validators(ctx, stateID, validatorIndices, validatorStates) if err != nil { return nil, err } diff --git a/multi/validators_test.go b/multi/validators_test.go index 7ce846cc..f528705a 100644 --- a/multi/validators_test.go +++ b/multi/validators_test.go @@ -18,6 +18,7 @@ import ( "testing" consensusclient "github.com/attestantio/go-eth2-client" + v1 "github.com/attestantio/go-eth2-client/api/v1" "github.com/attestantio/go-eth2-client/mock" "github.com/attestantio/go-eth2-client/multi" "github.com/attestantio/go-eth2-client/spec/phase0" @@ -51,7 +52,7 @@ func TestValidators(t *testing.T) { require.NoError(t, err) for i := 0; i < 128; i++ { - res, err := multiClient.(consensusclient.ValidatorsProvider).Validators(ctx, "1", []phase0.ValidatorIndex{}) + res, err := multiClient.(consensusclient.ValidatorsProvider).Validators(ctx, "1", []phase0.ValidatorIndex{}, []v1.ValidatorState{}) require.NoError(t, err) require.NotNil(t, res) } diff --git a/service.go b/service.go index c6554dd9..ccd6230b 100644 --- a/service.go +++ b/service.go @@ -368,7 +368,9 @@ type ValidatorsProvider interface { // stateID can be a slot number or state root, or one of the special values "genesis", "head", "justified" or "finalized". // validatorIndices is a list of validator indices to restrict the returned values. If no validators IDs are supplied no filter // will be applied. - Validators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex) (map[phase0.ValidatorIndex]*apiv1.Validator, error) + // validatorStates is a list of validator states to restrict the returned values. If no validators states are supplied no filter + // will be applied. + Validators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex, validatorStates []apiv1.ValidatorState) (map[phase0.ValidatorIndex]*apiv1.Validator, error) // ValidatorsByPubKey provides the validators, with their balance and status, for a given state. // stateID can be a slot number or state root, or one of the special values "genesis", "head", "justified" or "finalized". diff --git a/testclients/erroring.go b/testclients/erroring.go index e3994671..9d58007f 100644 --- a/testclients/erroring.go +++ b/testclients/erroring.go @@ -580,7 +580,9 @@ func (s *Erroring) ValidatorBalances(ctx context.Context, stateID string, valida // stateID can be a slot number or state root, or one of the special values "genesis", "head", "justified" or "finalized". // validatorIndices is a list of validator indices to restrict the returned values. If no validators IDs are supplied no filter // will be applied. -func (s *Erroring) Validators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex) (map[phase0.ValidatorIndex]*apiv1.Validator, error) { +// validatorStates is a list of validator states to restrict the returned values. If no validators states are supplied no filter +// will be applied. +func (s *Erroring) Validators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex, validatorStates []apiv1.ValidatorState) (map[phase0.ValidatorIndex]*apiv1.Validator, error) { if err := s.maybeError(ctx); err != nil { return nil, err } @@ -588,7 +590,7 @@ func (s *Erroring) Validators(ctx context.Context, stateID string, validatorIndi if !isNext { return nil, fmt.Errorf("%s@%s does not support this call", s.next.Name(), s.next.Address()) } - return next.Validators(ctx, stateID, validatorIndices) + return next.Validators(ctx, stateID, validatorIndices, validatorStates) } // ValidatorsByPubKey provides the validators, with their balance and status, for a given state. diff --git a/testclients/sleepy.go b/testclients/sleepy.go index 20a323be..77ba87cb 100644 --- a/testclients/sleepy.go +++ b/testclients/sleepy.go @@ -395,13 +395,15 @@ func (s *Sleepy) ValidatorBalances(ctx context.Context, stateID string, validato // stateID can be a slot number or state root, or one of the special values "genesis", "head", "justified" or "finalized". // validatorIndices is a list of validator indices to restrict the returned values. If no validators IDs are supplied no filter // will be applied. -func (s *Sleepy) Validators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex) (map[phase0.ValidatorIndex]*apiv1.Validator, error) { +// validatorStates is a list of validator states to restrict the returned values. If no validators states are supplied no filter +// will be applied. +func (s *Sleepy) Validators(ctx context.Context, stateID string, validatorIndices []phase0.ValidatorIndex, validatorStates []apiv1.ValidatorState) (map[phase0.ValidatorIndex]*apiv1.Validator, error) { s.sleep(ctx) next, isNext := s.next.(consensusclient.ValidatorsProvider) if !isNext { return nil, errors.New("next does not support this call") } - return next.Validators(ctx, stateID, validatorIndices) + return next.Validators(ctx, stateID, validatorIndices, validatorStates) } // ValidatorsByPubKey provides the validators, with their balance and status, for a given state.