Skip to content

Commit

Permalink
Merge pull request #271 from NethermindEth/starknet_call
Browse files Browse the repository at this point in the history
Update starknet_call
  • Loading branch information
cicr99 committed Aug 15, 2023
2 parents 2c54dc6 + 26f1bbf commit 53e7d9f
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 60 deletions.
14 changes: 9 additions & 5 deletions account.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const (

type account interface {
TransactionHash(calls []types.FunctionCall, details types.ExecuteDetails) (*big.Int, error)
Call(ctx context.Context, call types.FunctionCall) ([]string, error)
Call(ctx context.Context, call rpc.FunctionCall) ([]*felt.Felt, error)
Nonce(ctx context.Context) (*big.Int, error)
EstimateFee(ctx context.Context, calls []types.FunctionCall, details types.ExecuteDetails) (*types.FeeEstimate, error)
Execute(ctx context.Context, calls []types.FunctionCall, details types.ExecuteDetails) (*types.AddInvokeTransactionOutput, error)
Expand All @@ -44,12 +44,12 @@ type AccountPlugin interface {
type ProviderType string

const (
ProviderRPC ProviderType = "rpc"
ProviderRPC ProviderType = "rpc"
ProviderGateway ProviderType = "gateway"
)

type Account struct {
rpc *rpc.Provider
rpc *rpc.Provider
sequencer *gateway.GatewayProvider
provider ProviderType
chainId string
Expand Down Expand Up @@ -145,7 +145,7 @@ func NewGatewayAccount(sender, address *felt.Felt, ks Keystore, provider *gatewa
return account, nil
}

func (account *Account) Call(ctx context.Context, call types.FunctionCall) ([]string, error) {
func (account *Account) Call(ctx context.Context, call rpc.FunctionCall) ([]*felt.Felt, error) {
switch account.provider {
case ProviderRPC:
if account.rpc == nil {
Expand All @@ -162,7 +162,11 @@ func (account *Account) Call(ctx context.Context, call types.FunctionCall) ([]st
if account.sequencer == nil {
return nil, ErrUnsupportedAccount
}
return account.sequencer.Call(ctx, call, "latest")
resp, err := account.sequencer.Call(ctx, call, "latest")
if err != nil {
return nil, err
}
return utils.HexArrToFelt(resp)
}
return nil, ErrUnsupportedAccount
}
Expand Down
15 changes: 8 additions & 7 deletions accountgw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,23 @@ import (
"testing"
"time"

"github.com/NethermindEth/starknet.go/rpc"
"github.com/NethermindEth/starknet.go/types"
"github.com/NethermindEth/starknet.go/utils"
)

type TestAccountType struct {
PrivateKey string `json:"private_key"`
PublicKey string `json:"public_key"`
Address string `json:"address"`
Transactions []types.FunctionCall `json:"transactions,omitempty"`
PrivateKey string `json:"private_key"`
PublicKey string `json:"public_key"`
Address string `json:"address"`
Transactions []rpc.FunctionCall `json:"transactions,omitempty"`
}

func TestGatewayAccount_EstimateAndExecute(t *testing.T) {
testConfig := beforeGatewayEach(t)
type testSetType struct {
ExecuteCalls []types.FunctionCall
QueryCall types.FunctionCall
QueryCall rpc.FunctionCall
}

testSet := map[string][]testSetType{
Expand All @@ -30,7 +31,7 @@ func TestGatewayAccount_EstimateAndExecute(t *testing.T) {
EntryPointSelector: types.GetSelectorFromNameFelt("increment"),
ContractAddress: utils.TestHexToFelt(t, testConfig.CounterAddress),
}},
QueryCall: types.FunctionCall{
QueryCall: rpc.FunctionCall{
EntryPointSelector: types.GetSelectorFromNameFelt("get_count"),
ContractAddress: utils.TestHexToFelt(t, testConfig.CounterAddress),
},
Expand All @@ -40,7 +41,7 @@ func TestGatewayAccount_EstimateAndExecute(t *testing.T) {
EntryPointSelector: types.GetSelectorFromNameFelt("increment"),
ContractAddress: utils.TestHexToFelt(t, testConfig.CounterAddress),
}},
QueryCall: types.FunctionCall{
QueryCall: rpc.FunctionCall{
EntryPointSelector: types.GetSelectorFromNameFelt("get_count"),
ContractAddress: utils.TestHexToFelt(t, testConfig.CounterAddress),
},
Expand Down
3 changes: 2 additions & 1 deletion contracts/sessionkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
starknetgo "github.com/NethermindEth/starknet.go"
"github.com/NethermindEth/starknet.go/gateway"
"github.com/NethermindEth/starknet.go/plugins/xsessions"
"github.com/NethermindEth/starknet.go/rpc"
"github.com/NethermindEth/starknet.go/types"
"github.com/NethermindEth/starknet.go/utils"
)
Expand Down Expand Up @@ -138,7 +139,7 @@ func (ap *AccountManager) ExecuteWithGateway(counterAddress *felt.Felt, selector
return tx.TransactionHash.String(), nil
}

func (ap *AccountManager) CallWithGateway(call types.FunctionCall, provider *gateway.GatewayProvider) ([]string, error) {
func (ap *AccountManager) CallWithGateway(call rpc.FunctionCall, provider *gateway.GatewayProvider) ([]*felt.Felt, error) {
// shim in the keystore. while weird and awkward, it's functionally ok because
// 1. account manager doesn't seem to be used any where
// 2. the account that is created below is scoped to this func
Expand Down
3 changes: 2 additions & 1 deletion gateway/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ import (
"strings"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/starknet.go/rpc"
"github.com/NethermindEth/starknet.go/types"
)

func (sg *Gateway) AccountNonce(ctx context.Context, address *felt.Felt) (*big.Int, error) {
resp, err := sg.Call(ctx, types.FunctionCall{
resp, err := sg.Call(ctx, rpc.FunctionCall{
ContractAddress: address,
EntryPointSelector: types.GetSelectorFromNameFelt("get_nonce"),
}, "")
Expand Down
14 changes: 7 additions & 7 deletions gateway/starknet.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (f FunctionCall) MarshalJSON() ([]byte, error) {
/*
'call_contract' wrapper and can accept a blockId in the hash or height format
*/
func (sg *Gateway) Call(ctx context.Context, call types.FunctionCall, blockHashOrTag string) ([]string, error) {
func (sg *Gateway) Call(ctx context.Context, call rpc.FunctionCall, blockHashOrTag string) ([]string, error) {
gc := GatewayFunctionCall{
FunctionCall: FunctionCall(call),
}
Expand Down Expand Up @@ -220,12 +220,12 @@ func (sg *Gateway) Declare(ctx context.Context, contract rpc.ContractClass, decl
// }

type DeclareRequest struct {
Type string `json:"type"`
SenderAddress *felt.Felt `json:"sender_address"`
Version string `json:"version"`
MaxFee string `json:"max_fee"`
Nonce string `json:"nonce"`
Signature []string `json:"signature"`
Type string `json:"type"`
SenderAddress *felt.Felt `json:"sender_address"`
Version string `json:"version"`
MaxFee string `json:"max_fee"`
Nonce string `json:"nonce"`
Signature []string `json:"signature"`
ContractClass rpc.ContractClass `json:"contract_class"`
}

Expand Down
4 changes: 2 additions & 2 deletions gateway/starknet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,12 @@ func TestCall(t *testing.T) {
testConfig := beforeEach(t)

type testSetType struct {
Call types.FunctionCall
Call rpc.FunctionCall
}
testSet := map[string][]testSetType{
"devnet": {
{
Call: types.FunctionCall{
Call: rpc.FunctionCall{
ContractAddress: utils.TestHexToFelt(t, counterAddress),
EntryPointSelector: types.GetSelectorFromNameFelt("get_count"),
Calldata: []*felt.Felt{},
Expand Down
8 changes: 2 additions & 6 deletions rpc/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@ import (
)

// Call a starknet function without creating a StarkNet transaction.
func (provider *Provider) Call(ctx context.Context, request FunctionCall, blockID BlockID) ([]string, error) {
func (provider *Provider) Call(ctx context.Context, request FunctionCall, blockID BlockID) ([]*felt.Felt, error) {

if len(request.Calldata) == 0 {
request.Calldata = make([]*felt.Felt, 0)
}
var result []string
var result []*felt.Felt
if err := do(ctx, provider.c, "starknet_call", &result, request, blockID); err != nil {
switch {
case errors.Is(err, ErrContractNotFound):
return nil, ErrContractNotFound
case errors.Is(err, ErrInvalidMessageSelector):
return nil, ErrInvalidMessageSelector
case errors.Is(err, ErrInvalidCallData):
return nil, ErrInvalidCallData
case errors.Is(err, ErrContractError):
return nil, ErrContractError
case errors.Is(err, ErrBlockNotFound):
Expand Down
35 changes: 7 additions & 28 deletions rpc/call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package rpc

import (
"context"
"regexp"
"testing"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/starknet.go/types"
"github.com/NethermindEth/starknet.go/utils"
"github.com/test-go/testify/require"
)

// TestCall tests Call
Expand All @@ -17,7 +17,7 @@ func TestCall(t *testing.T) {
type testSetType struct {
FunctionCall FunctionCall
BlockID BlockID
ExpectedPatternResult string
ExpectedPatternResult *felt.Felt
}
testSet := map[string][]testSetType{
"devnet": {
Expand All @@ -29,7 +29,7 @@ func TestCall(t *testing.T) {
Calldata: []*felt.Felt{},
},
BlockID: WithBlockTag("latest"),
ExpectedPatternResult: "^0x[0-9a-f]+$",
ExpectedPatternResult: utils.TestHexToFelt(t, "0x6574686572"),
},
},
"mock": {
Expand All @@ -40,7 +40,7 @@ func TestCall(t *testing.T) {
Calldata: []*felt.Felt{},
},
BlockID: WithBlockTag("latest"),
ExpectedPatternResult: "^0x12$",
ExpectedPatternResult: utils.TestHexToFelt(t, "0xdeadbeef"),
},
},
"testnet": {
Expand All @@ -51,25 +51,7 @@ func TestCall(t *testing.T) {
Calldata: []*felt.Felt{},
},
BlockID: WithBlockTag("latest"),
ExpectedPatternResult: "^0x12$",
},
{
FunctionCall: FunctionCall{
ContractAddress: utils.TestHexToFelt(t, TestNetETHAddress),
EntryPointSelector: types.GetSelectorFromNameFelt("balanceOf"),
Calldata: []*felt.Felt{utils.TestHexToFelt(t, "0x0207aCC15dc241e7d167E67e30E769719A727d3E0fa47f9E187707289885Dfde")},
},
BlockID: WithBlockTag("latest"),
ExpectedPatternResult: "^0x[0-9a-f]+$",
},
{
FunctionCall: FunctionCall{
ContractAddress: utils.TestHexToFelt(t, TestNetAccount032Address),
EntryPointSelector: types.GetSelectorFromNameFelt("get_nonce"),
Calldata: []*felt.Felt{},
},
BlockID: WithBlockTag("latest"),
ExpectedPatternResult: "^0x[0-9a-f]+$",
ExpectedPatternResult: utils.TestHexToFelt(t, "0x12"),
},
},
"mainnet": {
Expand All @@ -80,7 +62,7 @@ func TestCall(t *testing.T) {
Calldata: []*felt.Felt{},
},
BlockID: WithBlockTag("latest"),
ExpectedPatternResult: "^0x12$",
ExpectedPatternResult: utils.TestHexToFelt(t, "0x12"),
},
},
}[testEnv]
Expand All @@ -99,10 +81,7 @@ func TestCall(t *testing.T) {
if len(output) == 0 {
t.Fatal("should return an output")
}
match, err := regexp.Match(test.ExpectedPatternResult, []byte(output[0]))
if err != nil || !match {
t.Fatalf("checking output(%v) expecting %s, got: %v", err, test.ExpectedPatternResult, output[0])
}
require.Equal(t, test.ExpectedPatternResult, output[0])

}
}
7 changes: 5 additions & 2 deletions rpc/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,11 @@ func mock_starknet_call(result interface{}, method string, args ...interface{})
fmt.Printf("args: %d\n", len(args))
return errWrongArgs
}
output := []string{"0x12"}
outputContent, _ := json.Marshal(output)
out, err := new(felt.Felt).SetString("0xdeadbeef")
if err != nil {
return err
}
outputContent, _ := json.Marshal([]*felt.Felt{out})
json.Unmarshal(outputContent, r)
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion rpc/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type api interface {
BlockTransactionCount(ctx context.Context, blockID BlockID) (uint64, error)
BlockWithTxHashes(ctx context.Context, blockID BlockID) (interface{}, error)
BlockWithTxs(ctx context.Context, blockID BlockID) (interface{}, error)
Call(ctx context.Context, call FunctionCall, block BlockID) ([]string, error)
Call(ctx context.Context, call FunctionCall, block BlockID) ([]*felt.Felt, error)
ChainID(ctx context.Context) (string, error)
Class(ctx context.Context, blockID BlockID, classHash string) (*ContractClass, error)
ClassAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*ContractClass, error)
Expand Down

0 comments on commit 53e7d9f

Please sign in to comment.