Skip to content

Commit

Permalink
feat: add 'disableChecks' query param on revert endpoint (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyrag authored Sep 14, 2023
1 parent 45e3e8a commit 8285d25
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 76 deletions.
6 changes: 6 additions & 0 deletions pkg/api/controllers/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,12 @@ paths:
format: int64
minimum: 0
example: 1234
- name: disableChecks
in: query
description: Allow to disable balances checks
required: false
schema:
type: boolean
responses:
"200":
description: OK
Expand Down
8 changes: 5 additions & 3 deletions pkg/api/controllers/transaction_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (ctl *TransactionController) PostTransaction(c *gin.Context) {
Reference: payload.Reference,
Metadata: payload.Metadata,
}
res, err := l.(*ledger.Ledger).ExecuteTxsData(c.Request.Context(), preview, txData)
res, err := l.(*ledger.Ledger).ExecuteTxsData(c.Request.Context(), preview, true, txData)
if err != nil {
apierrors.ResponseError(c, err)
return
Expand Down Expand Up @@ -326,7 +326,9 @@ func (ctl *TransactionController) RevertTransaction(c *gin.Context) {
return
}

tx, err := l.(*ledger.Ledger).RevertTransaction(c.Request.Context(), txId)
disableChecks := c.Query("disableChecks") == "1" || c.Query("disableChecks") == "true"

tx, err := l.(*ledger.Ledger).RevertTransaction(c.Request.Context(), txId, !disableChecks)
if err != nil {
apierrors.ResponseError(c, err)
return
Expand Down Expand Up @@ -392,7 +394,7 @@ func (ctl *TransactionController) PostTransactionsBatch(c *gin.Context) {
}
}

res, err := l.(*ledger.Ledger).ExecuteTxsData(c.Request.Context(), false, txs.Transactions...)
res, err := l.(*ledger.Ledger).ExecuteTxsData(c.Request.Context(), false, true, txs.Transactions...)
if err != nil {
apierrors.ResponseError(c, err)
return
Expand Down
47 changes: 44 additions & 3 deletions pkg/api/controllers/transaction_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,47 @@ func TestPostTransactionsOverdraft(t *testing.T) {
}))
}

func TestRevertWithDisableChecks(t *testing.T) {
internal.RunTest(t, fx.Invoke(func(lc fx.Lifecycle, api *api.API, driver storage.Driver[ledger.Store]) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {

rsp := internal.PostTransaction(t, api, controllers.PostTransaction{
Script: core.Script{
Plain: `
send [USD/2 100] (
source = @world
destination = @users:43
)
`,
},
}, false)
require.Equal(t, http.StatusOK, rsp.Result().StatusCode)
txs, ok := internal.DecodeSingleResponse[[]core.ExpandedTransaction](t, rsp.Body)
require.True(t, ok)
require.Len(t, txs, 1)

rsp = internal.PostTransaction(t, api, controllers.PostTransaction{
Script: core.Script{
Plain: `
send [USD/2 100] (
source = @users:43
destination = @blackhole
)
`,
},
}, false)
require.Equal(t, http.StatusOK, rsp.Result().StatusCode)

rsp = internal.RevertTransaction(api, txs[0].ID, true)
require.Equal(t, http.StatusOK, rsp.Result().StatusCode)

return nil
},
})
}))
}

func TestPostTransactionInvalidBody(t *testing.T) {
internal.RunTest(t, fx.Invoke(func(lc fx.Lifecycle, api *api.API) {
lc.Append(fx.Hook{
Expand Down Expand Up @@ -2064,7 +2105,7 @@ func TestRevertTransaction(t *testing.T) {
revertedTxID := cursor.Data[0].ID

t.Run("first revert should succeed", func(t *testing.T) {
rsp := internal.RevertTransaction(api, revertedTxID)
rsp := internal.RevertTransaction(api, revertedTxID, false)
require.Equal(t, http.StatusOK, rsp.Result().StatusCode)
res, _ := internal.DecodeSingleResponse[core.ExpandedTransaction](t, rsp.Body)
require.Equal(t, revertedTxID+1, res.ID)
Expand All @@ -2090,7 +2131,7 @@ func TestRevertTransaction(t *testing.T) {
})

t.Run("transaction not found", func(t *testing.T) {
rsp := internal.RevertTransaction(api, uint64(42))
rsp := internal.RevertTransaction(api, uint64(42), false)
require.Equal(t, http.StatusNotFound, rsp.Result().StatusCode, rsp.Body.String())
err := sharedapi.ErrorResponse{}
internal.Decode(t, rsp.Body, &err)
Expand All @@ -2103,7 +2144,7 @@ func TestRevertTransaction(t *testing.T) {
})

t.Run("second revert should fail", func(t *testing.T) {
rsp := internal.RevertTransaction(api, revertedTxID)
rsp := internal.RevertTransaction(api, revertedTxID, false)
require.Equal(t, http.StatusBadRequest, rsp.Result().StatusCode, rsp.Body.String())

err := sharedapi.ErrorResponse{}
Expand Down
7 changes: 6 additions & 1 deletion pkg/api/internal/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,13 @@ func GetTransaction(handler http.Handler, id uint64) *httptest.ResponseRecorder
return rec
}

func RevertTransaction(handler http.Handler, id uint64) *httptest.ResponseRecorder {
func RevertTransaction(handler http.Handler, id uint64, disableChecks bool) *httptest.ResponseRecorder {
req, rec := NewRequest(http.MethodPost, fmt.Sprintf("/"+testingLedger+"/transactions/%d/revert", id), nil)
if disableChecks {
query := req.URL.Query()
query.Set("disableChecks", "1")
req.URL.RawQuery = query.Encode()
}
handler.ServeHTTP(rec, req)
return rec
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/ledger/benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func BenchmarkLedger_PostTransactions_Postings_Single_FixedAccounts(b *testing.B
for n := 0; n < b.N; n++ {
_, err := txData.Postings.Validate()
require.NoError(b, err)
res, err = l.ExecuteTxsData(context.Background(), true, txData)
res, err = l.ExecuteTxsData(context.Background(), true, true, txData)
require.NoError(b, err)
require.Len(b, res, 1)
require.Len(b, res[0].Postings, nbPostings)
Expand Down Expand Up @@ -96,7 +96,7 @@ func BenchmarkLedger_PostTransactions_Postings_Batch_FixedAccounts(b *testing.B)
_, err := txData.Postings.Validate()
require.NoError(b, err)
}
res, err = l.ExecuteTxsData(context.Background(), true, txsData...)
res, err = l.ExecuteTxsData(context.Background(), true, true, txsData...)
require.NoError(b, err)
require.Len(b, res, 7)
require.Len(b, res[0].Postings, 1)
Expand Down Expand Up @@ -137,7 +137,7 @@ func BenchmarkLedger_PostTransactions_Postings_Batch_VaryingAccounts(b *testing.
_, err := txData.Postings.Validate()
require.NoError(b, err)
}
res, err = l.ExecuteTxsData(context.Background(), true, txsData...)
res, err = l.ExecuteTxsData(context.Background(), true, true, txsData...)
require.NoError(b, err)
require.Len(b, res, 7)
require.Len(b, res[0].Postings, 1)
Expand Down
18 changes: 9 additions & 9 deletions pkg/ledger/execute_script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestMappingIgnoreDestinations(t *testing.T) {
_, err := l.ExecuteScript(context.Background(), false, script)
require.NoError(t, err)

_, err = l.ExecuteTxsData(context.Background(), false, core.TransactionData{
_, err = l.ExecuteTxsData(context.Background(), false, true, core.TransactionData{
Postings: []core.Posting{{
Source: "B",
Destination: "A",
Expand All @@ -66,7 +66,7 @@ func TestMappingIgnoreDestinations(t *testing.T) {
})
require.NoError(t, err)

_, err = l.ExecuteTxsData(context.Background(), false, core.TransactionData{
_, err = l.ExecuteTxsData(context.Background(), false, true, core.TransactionData{
Postings: []core.Posting{{
Source: "B",
Destination: "A",
Expand Down Expand Up @@ -278,7 +278,7 @@ func TestEnoughFunds(t *testing.T) {
},
}

_, err := l.ExecuteTxsData(context.Background(), false, tx)
_, err := l.ExecuteTxsData(context.Background(), false, true, tx)
require.NoError(t, err)

script := core.ScriptData{
Expand Down Expand Up @@ -309,7 +309,7 @@ func TestNotEnoughFunds(t *testing.T) {
},
}

_, err := l.ExecuteTxsData(context.Background(), false, tx)
_, err := l.ExecuteTxsData(context.Background(), false, true, tx)
require.NoError(t, err)

script := core.ScriptData{
Expand Down Expand Up @@ -366,7 +366,7 @@ func TestMetadata(t *testing.T) {
},
}

_, err := l.ExecuteTxsData(context.Background(), false, tx)
_, err := l.ExecuteTxsData(context.Background(), false, true, tx)
require.NoError(t, err)

err = l.SaveMeta(context.Background(), core.MetaTargetTypeAccount,
Expand Down Expand Up @@ -624,7 +624,7 @@ func TestMonetaryVariableBalance(t *testing.T) {
},
},
}
_, err := l.ExecuteTxsData(context.Background(), false, tx)
_, err := l.ExecuteTxsData(context.Background(), false, true, tx)
require.NoError(t, err)

script := core.ScriptData{
Expand Down Expand Up @@ -665,7 +665,7 @@ func TestMonetaryVariableBalance(t *testing.T) {
},
},
}
_, err := l.ExecuteTxsData(context.Background(), false, tx)
_, err := l.ExecuteTxsData(context.Background(), false, true, tx)
require.NoError(t, err)

script := core.ScriptData{
Expand Down Expand Up @@ -706,7 +706,7 @@ func TestMonetaryVariableBalance(t *testing.T) {
},
},
}
_, err := l.ExecuteTxsData(context.Background(), false, tx)
_, err := l.ExecuteTxsData(context.Background(), false, true, tx)
require.NoError(t, err)

script := core.ScriptData{
Expand Down Expand Up @@ -742,7 +742,7 @@ func TestMonetaryVariableBalance(t *testing.T) {
},
},
}
_, err := l.ExecuteTxsData(context.Background(), false, tx)
_, err := l.ExecuteTxsData(context.Background(), false, true, tx)
require.NoError(t, err)

script := core.ScriptData{
Expand Down
60 changes: 31 additions & 29 deletions pkg/ledger/execute_txsdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/pkg/errors"
)

func (l *Ledger) ExecuteTxsData(ctx context.Context, preview bool, txsData ...core.TransactionData) ([]core.ExpandedTransaction, error) {
func (l *Ledger) ExecuteTxsData(ctx context.Context, preview, checkBalances bool, txsData ...core.TransactionData) ([]core.ExpandedTransaction, error) {
ctx, span := opentelemetry.Start(ctx, "ExecuteTxsData")
defer span.End()

Expand Down Expand Up @@ -98,38 +98,40 @@ func (l *Ledger) ExecuteTxsData(ctx context.Context, preview bool, txsData ...co
}
}

for account, volumes := range txVolumeAggr.PostCommitVolumes {
if _, ok := accs[account]; !ok {
accs[account], err = l.GetAccount(ctx, account)
if err != nil {
return []core.ExpandedTransaction{}, NewTransactionCommitError(i,
errors.Wrap(err, fmt.Sprintf("get account '%s'", account)))
}
}
for asset, vol := range volumes {
accs[account].Volumes[asset] = vol
}
accs[account].Balances = accs[account].Volumes.Balances()
for asset, volume := range volumes {
if account == core.WORLD {
continue
if checkBalances {
for account, volumes := range txVolumeAggr.PostCommitVolumes {
if _, ok := accs[account]; !ok {
accs[account], err = l.GetAccount(ctx, account)
if err != nil {
return []core.ExpandedTransaction{}, NewTransactionCommitError(i,
errors.Wrap(err, fmt.Sprintf("get account '%s'", account)))
}
}
if volume.Balance().Gte(txVolumeAggr.PreCommitVolumes[account][asset].Balance()) {
continue
for asset, vol := range volumes {
accs[account].Volumes[asset] = vol
}
accs[account].Balances = accs[account].Volumes.Balances()
for asset, volume := range volumes {
if account == core.WORLD {
continue
}
if volume.Balance().Gte(txVolumeAggr.PreCommitVolumes[account][asset].Balance()) {
continue
}

for _, contract := range contracts {
if contract.Match(account) {
if ok := contract.Expr.Eval(core.EvalContext{
Variables: map[string]interface{}{
"balance": volume.Balance(),
},
Metadata: accs[account].Metadata,
Asset: asset,
}); !ok {
return []core.ExpandedTransaction{}, NewInsufficientFundError(asset)
for _, contract := range contracts {
if contract.Match(account) {
if ok := contract.Expr.Eval(core.EvalContext{
Variables: map[string]interface{}{
"balance": volume.Balance(),
},
Metadata: accs[account].Metadata,
Asset: asset,
}); !ok {
return []core.ExpandedTransaction{}, NewInsufficientFundError(asset)
}
break
}
break
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/ledger/execute_txsdata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func TestLedger_ExecuteTxsData(t *testing.T) {
},
}

res, err := l.ExecuteTxsData(context.Background(), true, txsData...)
res, err := l.ExecuteTxsData(context.Background(), true, true, txsData...)
assert.NoError(t, err)

assert.Equal(t, len(txsData), len(res))
Expand Down Expand Up @@ -185,7 +185,7 @@ func TestLedger_ExecuteTxsData(t *testing.T) {
},
}

res, err := l.ExecuteTxsData(context.Background(), true, txsData...)
res, err := l.ExecuteTxsData(context.Background(), true, true, txsData...)
require.NoError(t, err)
require.Equal(t, len(txsData), len(res))

Expand Down Expand Up @@ -307,19 +307,19 @@ func TestLedger_ExecuteTxsData(t *testing.T) {
})

t.Run("no transaction data", func(t *testing.T) {
_, err := l.ExecuteTxsData(context.Background(), true)
_, err := l.ExecuteTxsData(context.Background(), true, true)
assert.Error(t, err)
assert.ErrorContains(t, err, "no transaction data to execute")
})

t.Run("no postings", func(t *testing.T) {
_, err := l.ExecuteTxsData(context.Background(), true, core.TransactionData{})
_, err := l.ExecuteTxsData(context.Background(), true, true, core.TransactionData{})
assert.Error(t, err)
assert.ErrorContains(t, err, "executing transaction data 0: no postings")
})

t.Run("amount zero", func(t *testing.T) {
res, err := l.ExecuteTxsData(context.Background(), true, core.TransactionData{
res, err := l.ExecuteTxsData(context.Background(), true, true, core.TransactionData{
Postings: core.Postings{
{
Source: "world",
Expand Down Expand Up @@ -347,7 +347,7 @@ func TestLedger_ExecuteTxsData(t *testing.T) {
},
}))

_, err := l.ExecuteTxsData(context.Background(), true,
_, err := l.ExecuteTxsData(context.Background(), true, true,
core.TransactionData{
Postings: []core.Posting{{
Source: "world",
Expand Down
4 changes: 2 additions & 2 deletions pkg/ledger/ledger.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (l *Ledger) LoadMapping(ctx context.Context) (*core.Mapping, error) {
return l.store.LoadMapping(ctx)
}

func (l *Ledger) RevertTransaction(ctx context.Context, id uint64) (*core.ExpandedTransaction, error) {
func (l *Ledger) RevertTransaction(ctx context.Context, id uint64, checkBalances bool) (*core.ExpandedTransaction, error) {
revertedTx, err := l.store.GetTransaction(ctx, id)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("getting transaction %d", id))
Expand All @@ -119,7 +119,7 @@ func (l *Ledger) RevertTransaction(ctx context.Context, id uint64) (*core.Expand
Reference: rt.Reference,
Metadata: rt.Metadata,
}
res, err := l.ExecuteTxsData(ctx, false, txData)
res, err := l.ExecuteTxsData(ctx, false, checkBalances, txData)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf(
"executing revert script for transaction %d", id))
Expand Down
Loading

0 comments on commit 8285d25

Please sign in to comment.