Skip to content

Commit

Permalink
return sql.ErrNotFound if account doesn't exist (#6508)
Browse files Browse the repository at this point in the history
## Motivation
The decision how to handle a situation when the queried account doesn't exist should be left to the caller - it's not the responsibility of the code accessing the cache/database.
  • Loading branch information
poszu committed Dec 2, 2024
1 parent feda54f commit 9bbdc0f
Show file tree
Hide file tree
Showing 14 changed files with 124 additions and 48 deletions.
6 changes: 0 additions & 6 deletions common/types/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,6 @@ func (a Address) String() string {
return result
}

// Format implements fmt.Formatter, forcing the byte slice to be formatted as is,
// without going through the stringer interface used for logging.
func (a Address) Format(s fmt.State, c rune) {
fmt.Fprintf(s, "%"+string(c), a[:])
}

// EncodeScale implements scale codec interface.
func (a *Address) EncodeScale(e *scale.Encoder) (int, error) {
return scale.EncodeByteArray(e, a[:])
Expand Down
12 changes: 0 additions & 12 deletions common/types/hashes.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ func (h Hash20) ShortString() string {
return hex.EncodeToString(h[:5])
}

// Format implements fmt.Formatter, forcing the byte slice to be formatted as is,
// without going through the stringer interface used for logging.
func (h Hash20) Format(s fmt.State, c rune) {
fmt.Fprintf(s, "%"+string(c), h[:])
}

// UnmarshalText parses a hash in hex syntax.
func (h *Hash20) UnmarshalText(input []byte) error {
if err := util.UnmarshalFixedText("Hash", input, h[:]); err != nil {
Expand Down Expand Up @@ -163,12 +157,6 @@ func (h Hash32) ShortString() string {
return hex.EncodeToString(h[:5])
}

// Format implements fmt.Formatter, forcing the byte slice to be formatted as is,
// without going through the stringer interface used for logging.
func (h Hash32) Format(s fmt.State, c rune) {
fmt.Fprintf(s, "%"+string(c), h[:])
}

// UnmarshalText parses a hash in hex syntax.
func (h *Hash32) UnmarshalText(input []byte) error {
if err := util.UnmarshalFixedText("Hash", input, h[:]); err != nil {
Expand Down
10 changes: 10 additions & 0 deletions common/types/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package types

import (
"crypto/rand"
"testing"

"github.com/stretchr/testify/require"
)

// RandomBytes generates random data in bytes for testing.
Expand Down Expand Up @@ -137,3 +140,10 @@ func RandomVrfSignature() VrfSignature {
}
return VrfSignature(b)
}

func RandomAddress(tb testing.TB) Address {
var a Address
_, err := rand.Read(a[:])
require.NoError(tb, err)
return a
}
6 changes: 5 additions & 1 deletion genvm/core/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package core

import (
"bytes"
"errors"
"fmt"

"github.com/spacemeshos/go-scale"
Expand Down Expand Up @@ -235,7 +236,10 @@ func (c *Context) load(address types.Address) (*Account, error) {
account, exist := c.changed[address]
if !exist {
loaded, err := c.Loader.Get(address)
if err != nil {
switch {
case errors.Is(err, ErrNotFound):
loaded = types.Account{Address: address}
case err != nil:
return nil, fmt.Errorf("%w: %w", ErrInternal, err)
}
account = &loaded
Expand Down
6 changes: 2 additions & 4 deletions genvm/core/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,8 @@ func TestRelay(t *testing.T) {
require.Equal(t, amount1, int(rec1state.Balance))
require.NotEqual(t, encoded, rec1state.State)

rec2state, err := cache.Get(receiver2)
require.NoError(t, err)
require.Equal(t, 0, int(rec2state.Balance))
require.NotEqual(t, encoded, rec2state.State)
_, err = cache.Get(receiver2)
require.ErrorIs(t, err, core.ErrNotFound) // relay to receiver2 failed

remoteState, err := cache.Get(remote)
require.NoError(t, err)
Expand Down
8 changes: 7 additions & 1 deletion genvm/core/staged_cache.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package core

import (
"errors"

"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/accounts"
Expand All @@ -11,7 +13,11 @@ type DBLoader struct {
}

func (db DBLoader) Get(address types.Address) (types.Account, error) {
return accounts.Latest(db.Executor, address)
account, err := accounts.Latest(db.Executor, address)
if errors.Is(err, sql.ErrNotFound) {
return types.Account{}, ErrNotFound
}
return account, err
}

// NewStagedCache returns instance of the staged cache.
Expand Down
1 change: 1 addition & 0 deletions genvm/core/staged_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func TestCacheGetCopies(t *testing.T) {
db := statesql.InMemoryTest(t)
ss := core.NewStagedCache(core.DBLoader{db})
address := core.Address{1}
ss.Update(core.Account{Address: address})
account, err := ss.Get(address)
require.NoError(t, err)
account.Balance = 100
Expand Down
7 changes: 7 additions & 0 deletions genvm/core/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package core

import (
"errors"

"github.com/spacemeshos/go-scale"

"github.com/spacemeshos/go-spacemesh/common/types"
Expand Down Expand Up @@ -76,8 +78,13 @@ type Template interface {
Verify(Host, []byte, *scale.Decoder) bool
}

var ErrNotFound = errors.New("not found")

// AccountLoader is an interface for loading accounts.
type AccountLoader interface {
// Get account for given address
//
// Returns ErrNotFound if the account doesn't exist.
Get(Address) (Account, error)
}

Expand Down
6 changes: 5 additions & 1 deletion genvm/rewards.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vm

import (
"errors"
"fmt"
"math/big"

Expand Down Expand Up @@ -67,7 +68,10 @@ func (v *VM) addRewards(
}
result = append(result, reward)
account, err := ss.Get(blockReward.Coinbase)
if err != nil {
switch {
case errors.Is(err, core.ErrNotFound):
account = types.Account{Address: blockReward.Coinbase}
case err != nil:
return nil, fmt.Errorf("%w: %w", core.ErrInternal, err)
}
account.Balance += reward.TotalReward
Expand Down
17 changes: 13 additions & 4 deletions genvm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ func (v *VM) AccountExists(address core.Address) (bool, error) {
// GetNonce returns expected next nonce for the address.
func (v *VM) GetNonce(address core.Address) (core.Nonce, error) {
account, err := accounts.Latest(v.db, address)
if err != nil {
switch {
case errors.Is(err, sql.ErrNotFound):
return 0, nil
case err != nil:
return 0, err
}
return account.NextNonce, nil
Expand All @@ -165,7 +168,10 @@ func (v *VM) GetNonce(address core.Address) (core.Nonce, error) {
// GetBalance returns balance for an address.
func (v *VM) GetBalance(address types.Address) (uint64, error) {
account, err := accounts.Latest(v.db, address)
if err != nil {
switch {
case errors.Is(err, sql.ErrNotFound):
return 0, nil
case err != nil:
return 0, err
}
return account.Balance, nil
Expand Down Expand Up @@ -501,9 +507,12 @@ func parse(
return nil, nil, nil, fmt.Errorf("%w: failed to decode method selector %w", core.ErrMalformed, err)
}
account, err := loader.Get(principal)
if err != nil {
switch {
case errors.Is(err, core.ErrNotFound):
account = types.Account{Address: principal}
case err != nil:
return nil, nil, nil, fmt.Errorf(
"%w: failed load state for principal %s - %w",
"%w: failed load state for principal %s: %w",
core.ErrInternal,
principal,
err,
Expand Down
10 changes: 8 additions & 2 deletions genvm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/spacemeshos/go-spacemesh/genvm/templates/wallet"
"github.com/spacemeshos/go-spacemesh/hash"
"github.com/spacemeshos/go-spacemesh/signing"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/accounts"
"github.com/spacemeshos/go-spacemesh/sql/layers"
"github.com/spacemeshos/go-spacemesh/sql/statesql"
Expand Down Expand Up @@ -1441,9 +1442,13 @@ func runTestCases(t *testing.T, tcs []templateTestCase, genTester func(t *testin
}
for account, changes := range layer.expected {
prev, err := accounts.Get(tt.db, tt.accounts[account].getAddress(), lid.Sub(1))
require.NoError(tt, err)
if err != nil {
require.ErrorIs(t, err, sql.ErrNotFound)
}
current, err := accounts.Get(tt.db, tt.accounts[account].getAddress(), lid)
require.NoError(tt, err)
if err != nil {
require.ErrorIs(t, err, sql.ErrNotFound)
}
tt.Logf("verifying account index=%d in layer index=%d", account, i)
changes.verify(tt, &prev, &current)
}
Expand Down Expand Up @@ -1652,6 +1657,7 @@ func testValidation(t *testing.T, tt *tester, template core.Address) {
if tc.err != nil {
require.ErrorIs(t, err, tc.err)
} else {
require.NoError(t, err)
require.Equal(t, tc.verified, req.Verify())
if tc.verified {
require.Equal(t, tc.header, header)
Expand Down
2 changes: 1 addition & 1 deletion node/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func runRelay(ctx context.Context, cfg *config.Config) error {

types.SetLayersPerEpoch(cfg.LayersPerEpoch)
prologue := fmt.Sprintf("%x-%v",
cfg.Genesis.GenesisID(),
cfg.Genesis.GenesisID().Bytes(),
types.GetEffectiveGenesis(),
)
// Prevent testnet nodes from working on the mainnet, but
Expand Down
24 changes: 8 additions & 16 deletions sql/accounts/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func Has(db sql.Executor, address types.Address) (bool, error) {
// Latest latest account data for an address.
func Latest(db sql.Executor, address types.Address) (types.Account, error) {
var account types.Account
_, err := db.Exec(`
rows, err := db.Exec(`
select balance, next_nonce, layer_updated, template, state from accounts
where address = ?1
order by layer_updated desc;`,
Expand All @@ -46,20 +46,16 @@ func Latest(db sql.Executor, address types.Address) (types.Account, error) {
if err != nil {
return types.Account{}, fmt.Errorf("failed to load %v: %w", address, err)
}
// TODO(mafa): returning `sql.ErrNotFound` causes a bunch of tests to fail, some even panic
// this needs to be investigated and fixed
//
// if account.Address != address {
// return types.Account{}, sql.ErrNotFound
// }
account.Address = address // without this tests are failing not only assertions but are also panicking
if rows == 0 {
return types.Account{}, sql.ErrNotFound
}
return account, nil
}

// Get account data that was valid at the specified layer.
func Get(db sql.Executor, address types.Address, layer types.LayerID) (types.Account, error) {
var account types.Account
_, err := db.Exec(`
rows, err := db.Exec(`
select balance, next_nonce, layer_updated, template, state from accounts
where address = ?1 and layer_updated <= ?2
order by layer_updated desc;`,
Expand All @@ -84,13 +80,9 @@ func Get(db sql.Executor, address types.Address, layer types.LayerID) (types.Acc
if err != nil {
return types.Account{}, fmt.Errorf("failed to load %v for layer %v: %w", address, layer, err)
}
// TODO(mafa): returning `sql.ErrNotFound` causes a bunch of tests to fail, some even panic
// this needs to be investigated and fixed
//
// if account.Address != address {
// return types.Account{}, sql.ErrNotFound
// }
account.Address = address // without this tests are failing not only assertions but are also panicking
if rows == 0 {
return types.Account{}, sql.ErrNotFound
}
return account, nil
}

Expand Down
57 changes: 57 additions & 0 deletions sql/accounts/accounts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,63 @@ func TestHas(t *testing.T) {
require.True(t, has)
}

func TestLatest(t *testing.T) {
t.Run("doesn't exist", func(t *testing.T) {
db := statesql.InMemoryTest(t)
account, err := Latest(db, types.RandomAddress(t))
require.ErrorIs(t, err, sql.ErrNotFound)
require.Empty(t, account)
})
t.Run("picks latest", func(t *testing.T) {
address := types.RandomAddress(t)
db := statesql.InMemoryTest(t)
err := Update(db, &types.Account{
Address: address,
})
require.NoError(t, err)
account := types.Account{
Layer: 1,
NextNonce: 1,
Balance: 100,
Address: address,
}
err = Update(db, &account)
require.NoError(t, err)

got, err := Latest(db, address)
require.NoError(t, err)
require.Equal(t, account, got)
})
}

func TestGet(t *testing.T) {
t.Run("doesn't exist", func(t *testing.T) {
db := statesql.InMemoryTest(t)
account, err := Get(db, types.RandomAddress(t), 0)
require.ErrorIs(t, err, sql.ErrNotFound)
require.Empty(t, account)
})
t.Run("picks the right one", func(t *testing.T) {
address := types.RandomAddress(t)
db := statesql.InMemoryTest(t)
account := types.Account{
Layer: 1,
NextNonce: 1,
Balance: 100,
Address: address,
}
err := Update(db, &account)
require.NoError(t, err)

_, err = Get(db, address, 0)
require.ErrorIs(t, err, sql.ErrNotFound)

got, err := Get(db, address, 1)
require.NoError(t, err)
require.Equal(t, account, got)
})
}

func TestRevert(t *testing.T) {
address := types.Address{1, 1}
seq := genSeq(address, 10)
Expand Down

0 comments on commit 9bbdc0f

Please sign in to comment.