Skip to content

Commit

Permalink
Decouple TestVerify into TestValidateBasic and TestVerify (cosmos#1231)
Browse files Browse the repository at this point in the history
<!--
Please read and fill out this form before submitting your PR.

Please make sure you have reviewed our contributors guide before
submitting your
first PR.
-->

## Overview
Closes: cosmos#1189 

<!-- 
Please provide an explanation of the PR, including the appropriate
context,
background, goal, and rationale. If there is an issue with this
information,
please provide a tl;dr and link the issue. 
-->

## Checklist

<!-- 
Please complete the checklist to ensure that the PR is ready to be
reviewed.

IMPORTANT:
PRs should be left in Draft until the below checklist is completed.
-->

- [x] New and updated code has appropriate documentation
- [x] New and updated code has new and/or updated testing
- [x] Required CI checks are passing
- [ ] Visual proof for any user facing features like CLI or
documentation updates
- [x] Linked issues closed with keywords
  • Loading branch information
Manav-Aggarwal authored Oct 6, 2023
1 parent 3e7ff30 commit 0f49204
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 50 deletions.
11 changes: 1 addition & 10 deletions types/signed_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,7 @@ var (
)

func (sh *SignedHeader) Verify(untrstH *SignedHeader) error {
// TODO(@Wondertan):
// We keep this check because of how unit tests are structured where TestVerify tests both ValidateBasic and Verify.
// While the check is redundant as go-header ensures untrustH passed ValidateBasic before.
// Decoupling TestVerify into TestValidateBasic and TestVerify would allow to remove this check.
if err := untrstH.ValidateBasic(); err != nil {
return &header.VerifyError{
Reason: err,
}
}

// go-header ensures untrustH already passed ValidateBasic.
if err := sh.Header.Verify(&untrstH.Header); err != nil {
return &header.VerifyError{
Reason: err,
Expand Down
121 changes: 82 additions & 39 deletions types/signed_header_test.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
package types

import (
"strconv"
"fmt"
"testing"
"time"

"github.com/celestiaorg/go-header"
"github.com/cometbft/cometbft/crypto/ed25519"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestVerify(t *testing.T) {
func TestSignedHeader(t *testing.T) {
// Generate a random signed header
trusted, privKey, err := GetRandomSignedHeader()
require.NoError(t, err)
time.Sleep(time.Second)
// Get the next random header
untrustedAdj, err := GetNextRandomHeader(trusted, privKey)
require.NoError(t, err)
fakeAggregatorsHash := header.Hash(GetRandomBytes(32))
fakeLastHeaderHash := header.Hash(GetRandomBytes(32))
fakeLastCommitHash := header.Hash(GetRandomBytes(32))
t.Run("Test Verify", func(t *testing.T) {
testVerify(t, trusted, untrustedAdj, privKey)
})
t.Run("Test ValidateBasic", func(t *testing.T) {
testValidateBasic(t, untrustedAdj, privKey)
})
}

func testVerify(t *testing.T, trusted *SignedHeader, untrustedAdj *SignedHeader, privKey ed25519.PrivKey) {
tests := []struct {
prepare func() (*SignedHeader, bool)
err error
prepare func() (*SignedHeader, bool) // Function to prepare the test case
err error // Expected error
}{
{
prepare: func() (*SignedHeader, bool) { return untrustedAdj, false },
Expand All @@ -30,17 +37,7 @@ func TestVerify(t *testing.T) {
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.AggregatorsHash = fakeAggregatorsHash
return &untrusted, false
},
err: &header.VerifyError{
Reason: ErrAggregatorSetHashMismatch,
},
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.LastHeaderHash = fakeLastHeaderHash
untrusted.LastHeaderHash = header.Hash(GetRandomBytes(32))
return &untrusted, true
},
err: &header.VerifyError{
Expand All @@ -50,7 +47,7 @@ func TestVerify(t *testing.T) {
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.LastCommitHash = fakeLastCommitHash
untrusted.LastCommitHash = header.Hash(GetRandomBytes(32))
return &untrusted, true
},
err: &header.VerifyError{
Expand All @@ -68,58 +65,104 @@ func TestVerify(t *testing.T) {
Reason: ErrNonAdjacentHeaders,
},
},
}

for testIndex, test := range tests {
t.Run(fmt.Sprintf("Test #%d", testIndex), func(t *testing.T) {
preparedHeader, shouldRecomputeCommit := test.prepare()

if shouldRecomputeCommit {
commit, err := getCommit(preparedHeader.Header, privKey)
require.NoError(t, err)
preparedHeader.Commit = *commit
}

err := trusted.Verify(preparedHeader)

if test.err == nil {
assert.NoError(t, err)
return
}

if err == nil {
t.Errorf("expected error: %v, but got nil", test.err)
return
}

reason := err.(*header.VerifyError).Reason
expectedReason := test.err.(*header.VerifyError).Reason
assert.ErrorIs(t, reason, expectedReason)
})
}
}

func testValidateBasic(t *testing.T, untrustedAdj *SignedHeader, privKey ed25519.PrivKey) {
// Define test cases
tests := []struct {
prepare func() (*SignedHeader, bool) // Function to prepare the test case
err error // Expected error
}{
{
prepare: func() (*SignedHeader, bool) { return untrustedAdj, false },
err: nil,
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.AggregatorsHash = header.Hash(GetRandomBytes(32))
return &untrusted, false
},
err: ErrAggregatorSetHashMismatch,
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.BaseHeader.ChainID = "toaster"
return &untrusted, false // Signature verification should fail
},
err: &header.VerifyError{
Reason: ErrSignatureVerificationFailed,
},
err: ErrSignatureVerificationFailed,
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.Version.App = untrusted.Version.App + 1
return &untrusted, false // Signature verification should fail
},
err: &header.VerifyError{
Reason: ErrSignatureVerificationFailed,
},
err: ErrSignatureVerificationFailed,
},
{
prepare: func() (*SignedHeader, bool) {
untrusted := *untrustedAdj
untrusted.ProposerAddress = nil
return &untrusted, true
},
err: &header.VerifyError{
Reason: ErrNoProposerAddress,
},
err: ErrNoProposerAddress,
},
}

for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
preparedHeader, recomputeCommit := test.prepare()
if recomputeCommit {
for testIndex, test := range tests {
t.Run(fmt.Sprintf("Test #%d", testIndex), func(t *testing.T) {
preparedHeader, shouldRecomputeCommit := test.prepare()

if shouldRecomputeCommit {
commit, err := getCommit(preparedHeader.Header, privKey)
require.NoError(t, err)
preparedHeader.Commit = *commit
}
err = trusted.Verify(preparedHeader)

err := preparedHeader.ValidateBasic()

if test.err == nil {
assert.NoError(t, err)
return
}

if err == nil {
t.Errorf("expected err: %v, got nil", test.err)
t.Errorf("expected error: %v, but got nil", test.err)
return
}
reason := err.(*header.VerifyError).Reason
testReason := test.err.(*header.VerifyError).Reason
assert.ErrorIs(t, reason, testReason)

assert.ErrorIs(t, err, test.err)
})
}
}
2 changes: 1 addition & 1 deletion types/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func GetNextRandomHeader(signedHeader *SignedHeader, privKey ed25519.PrivKey) (*
BaseHeader: BaseHeader{
ChainID: "test",
Height: signedHeader.Height() + 1,
Time: uint64(time.Now().UnixNano()),
Time: uint64(signedHeader.Time().Add(1 * time.Second).UnixNano()),
},
LastHeaderHash: signedHeader.Hash(),
DataHash: GetRandomBytes(32),
Expand Down

0 comments on commit 0f49204

Please sign in to comment.