Skip to content

Commit

Permalink
Error handling improvement for asserter package (#436)
Browse files Browse the repository at this point in the history
* feat: asserter error handling

Signed-off-by: Jingfu Wang <jingfu.wang@coinbase.com>

* fix: make gen

Signed-off-by: Jingfu Wang <jingfu.wang@coinbase.com>

* test: add tests back

Signed-off-by: Jingfu Wang <jingfu.wang@coinbase.com>

Signed-off-by: Jingfu Wang <jingfu.wang@coinbase.com>
  • Loading branch information
GeekArthur authored Aug 26, 2022
1 parent e82db65 commit a48f742
Show file tree
Hide file tree
Showing 19 changed files with 615 additions and 247 deletions.
39 changes: 28 additions & 11 deletions asserter/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,26 @@ func ContainsCurrency(currencies []*types.Currency, currency *types.Currency) bo

// AssertUniqueAmounts returns an error if a slice
// of types.Amount is invalid. It is considered invalid if the same
// currency is returned multiple times (these shoould be
// currency is returned multiple times (these should be
// consolidated) or if a types.Amount is considered invalid.
func AssertUniqueAmounts(amounts []*types.Amount) error {
seen := map[string]struct{}{}
for _, amount := range amounts {
// Ensure a currency is used at most once
key := types.Hash(amount.Currency)
if _, ok := seen[key]; ok {
return fmt.Errorf("currency %+v used multiple times", amount.Currency)
return fmt.Errorf(
"amount currency %s of amount %s is invalid: %w",
types.PrintStruct(amount.Currency),
types.PrintStruct(amount),
ErrCurrencyUsedMultipleTimes,
)
}
seen[key] = struct{}{}

// Check amount for validity
if err := Amount(amount); err != nil {
return err
return fmt.Errorf("amount %s is invalid: %w", types.PrintStruct(amount), err)
}
}

Expand All @@ -83,11 +88,19 @@ func AccountBalanceResponse(
response *types.AccountBalanceResponse,
) error {
if err := BlockIdentifier(response.BlockIdentifier); err != nil {
return fmt.Errorf("%w: block identifier is invalid", err)
return fmt.Errorf(
"block identifier %s is invalid: %w",
types.PrintStruct(response.BlockIdentifier),
err,
)
}

if err := AssertUniqueAmounts(response.Balances); err != nil {
return fmt.Errorf("%w: balance amounts are invalid", err)
return fmt.Errorf(
"balance amounts %s are invalid: %w",
types.PrintStruct(response.Balances),
err,
)
}

if requestBlock == nil {
Expand All @@ -96,19 +109,19 @@ func AccountBalanceResponse(

if requestBlock.Hash != nil && *requestBlock.Hash != response.BlockIdentifier.Hash {
return fmt.Errorf(
"%w: requested block hash %s but got %s",
ErrReturnedBlockHashMismatch,
"requested block hash %s, but got %s: %w",
*requestBlock.Hash,
response.BlockIdentifier.Hash,
ErrReturnedBlockHashMismatch,
)
}

if requestBlock.Index != nil && *requestBlock.Index != response.BlockIdentifier.Index {
return fmt.Errorf(
"%w: requested block index %d but got %d",
ErrReturnedBlockIndexMismatch,
"requested block index %d, but got %d: %w",
*requestBlock.Index,
response.BlockIdentifier.Index,
ErrReturnedBlockIndexMismatch,
)
}

Expand All @@ -121,11 +134,15 @@ func AccountCoinsResponse(
response *types.AccountCoinsResponse,
) error {
if err := BlockIdentifier(response.BlockIdentifier); err != nil {
return fmt.Errorf("%w: block identifier is invalid", err)
return fmt.Errorf(
"block identifier %s is invalid: %w",
types.PrintStruct(response.BlockIdentifier),
err,
)
}

if err := Coins(response.Coins); err != nil {
return fmt.Errorf("%w: coins are invalid", err)
return fmt.Errorf("coins %s are invalid: %w", types.PrintStruct(response.Coins), err)
}

return nil
Expand Down
10 changes: 5 additions & 5 deletions asserter/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func TestAccountBalance(t *testing.T) {
validAmount,
validAmount,
},
err: fmt.Errorf("currency %+v used multiple times", validAmount.Currency),
err: ErrCurrencyUsedMultipleTimes,
},
"valid historical request index": {
requestBlock: &types.PartialBlockIdentifier{
Expand Down Expand Up @@ -322,10 +322,10 @@ func TestAccountBalance(t *testing.T) {
validAmount,
},
err: fmt.Errorf(
"%w: requested block index %d but got %d",
ErrReturnedBlockIndexMismatch,
"requested block index %d, but got %d: %w",
invalidIndex,
validBlock.Index,
ErrReturnedBlockIndexMismatch,
),
},
"invalid historical request hash": {
Expand All @@ -338,10 +338,10 @@ func TestAccountBalance(t *testing.T) {
validAmount,
},
err: fmt.Errorf(
"%w: requested block hash %s but got %s",
ErrReturnedBlockHashMismatch,
"requested block hash %s, but got %s: %w",
invalidHash,
validBlock.Hash,
ErrReturnedBlockHashMismatch,
),
},
}
Expand Down
74 changes: 55 additions & 19 deletions asserter/asserter.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,20 @@ func NewServer(
validationFilePath string,
) (*Asserter, error) {
if err := OperationTypes(supportedOperationTypes); err != nil {
return nil, err
return nil, fmt.Errorf("operation types %v are invalid: %w", supportedOperationTypes, err)
}

if err := SupportedNetworks(supportedNetworks); err != nil {
return nil, err
return nil, fmt.Errorf(
"network identifiers %s are invalid: %w",
types.PrintStruct(supportedNetworks),
err,
)
}

validationConfig, err := getValidationConfig(validationFilePath)
if err != nil {
return nil, err
return nil, fmt.Errorf("config %s is invalid: %w", validationFilePath, err)
}

callMap := map[string]struct{}{}
Expand All @@ -100,7 +104,7 @@ func NewServer(
}

if _, ok := callMap[method]; ok {
return nil, fmt.Errorf("%w: %s", ErrCallMethodDuplicate, method)
return nil, fmt.Errorf("failed to call method %s: %w", method, ErrCallMethodDuplicate)
}

callMap[method] = struct{}{}
Expand All @@ -126,20 +130,32 @@ func NewClientWithResponses(
validationFilePath string,
) (*Asserter, error) {
if err := NetworkIdentifier(network); err != nil {
return nil, err
return nil, fmt.Errorf(
"network identifier %s is invalid: %w",
types.PrintStruct(network),
err,
)
}

if err := NetworkStatusResponse(networkStatus); err != nil {
return nil, err
return nil, fmt.Errorf(
"network status response %s is invalid: %w",
types.PrintStruct(networkStatus),
err,
)
}

if err := NetworkOptionsResponse(networkOptions); err != nil {
return nil, err
return nil, fmt.Errorf(
"network options response %s is invalid: %w",
types.PrintStruct(networkOptions),
err,
)
}

validationConfig, err := getValidationConfig(validationFilePath)
if err != nil {
return nil, err
return nil, fmt.Errorf("config %s is invalid: %w", validationFilePath, err)
}

return NewClientWithOptions(
Expand Down Expand Up @@ -175,12 +191,12 @@ func NewClientWithFile(
) (*Asserter, error) {
content, err := ioutil.ReadFile(path.Clean(filePath))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to read file %s: %w", filePath, err)
}

config := &Configuration{}
if err := json.Unmarshal(content, config); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal content of file %s: %w", filePath, err)
}

return NewClientWithOptions(
Expand Down Expand Up @@ -209,19 +225,35 @@ func NewClientWithOptions(
validationConfig *Validations,
) (*Asserter, error) {
if err := NetworkIdentifier(network); err != nil {
return nil, err
return nil, fmt.Errorf(
"network identifier %s is invalid: %w",
types.PrintStruct(network),
err,
)
}

if err := BlockIdentifier(genesisBlockIdentifier); err != nil {
return nil, err
return nil, fmt.Errorf(
"genesis block identifier %s is invalid: %w",
types.PrintStruct(genesisBlockIdentifier),
err,
)
}

if err := OperationStatuses(operationStatuses); err != nil {
return nil, err
return nil, fmt.Errorf(
"operation statuses %s are invalid: %w",
types.PrintStruct(operationStatuses),
err,
)
}

if err := OperationTypes(operationTypes); err != nil {
return nil, err
return nil, fmt.Errorf(
"operation types %s are invalid: %w",
types.PrintStruct(operationTypes),
err,
)
}

// TimestampStartIndex defaults to genesisIndex + 1 (this
Expand All @@ -230,9 +262,9 @@ func NewClientWithOptions(
if timestampStartIndex != nil {
if *timestampStartIndex < 0 {
return nil, fmt.Errorf(
"%w: %d",
ErrTimestampStartIndexInvalid,
"failed to validate index %d: %w",
*timestampStartIndex,
ErrTimestampStartIndexInvalid,
)
}

Expand Down Expand Up @@ -304,7 +336,7 @@ func (a *Asserter) OperationSuccessful(operation *types.Operation) (bool, error)

val, ok := a.operationStatusMap[*operation.Status]
if !ok {
return false, fmt.Errorf("%s not found", *operation.Status)
return false, fmt.Errorf("operation status %s is not found", *operation.Status)
}

return val, nil
Expand All @@ -317,11 +349,15 @@ func getValidationConfig(validationFilePath string) (*Validations, error) {
if validationFilePath != "" {
content, err := ioutil.ReadFile(path.Clean(validationFilePath))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to read file %s: %w", validationFilePath, err)
}

if err := json.Unmarshal(content, validationConfig); err != nil {
return nil, err
return nil, fmt.Errorf(
"failed to unmarshal content of file %s: %w",
validationFilePath,
err,
)
}
}
return validationConfig, nil
Expand Down
6 changes: 3 additions & 3 deletions asserter/asserter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ func TestNew(t *testing.T) {
networkOptions: negativeStartIndex,
validationFilePath: "",

err: errors.New("TimestampStartIndex is invalid: -1"),
err: ErrTimestampStartIndexInvalid,
},
}

Expand All @@ -356,7 +356,7 @@ func TestNew(t *testing.T) {

if test.err != nil {
assert.Error(t, err)
assert.Contains(t, test.err.Error(), err.Error())
assert.Contains(t, err.Error(), test.err.Error())
return
}
assert.NoError(t, err)
Expand Down Expand Up @@ -427,7 +427,7 @@ func TestNew(t *testing.T) {

if test.err != nil {
assert.Error(t, err)
assert.Contains(t, test.err.Error(), err.Error())
assert.Contains(t, err.Error(), test.err.Error())
return
}
assert.NoError(t, err)
Expand Down
Loading

0 comments on commit a48f742

Please sign in to comment.