Skip to content

Commit

Permalink
Merge branch 'estimateMsgFee' into test_merge_rpcv04
Browse files Browse the repository at this point in the history
  • Loading branch information
rianhughes committed Aug 21, 2023
2 parents e409e6e + 5ed6674 commit 5a71ee9
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 0 deletions.
17 changes: 17 additions & 0 deletions rpc/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,20 @@ func (provider *Provider) EstimateFee(ctx context.Context, requests []Broadcaste
}
return raw, nil
}

// EstimateMessageFee estimates the L2 fee of a message sent on L1
func (provider *Provider) EstimateMessageFee(ctx context.Context, msg MsgFromL1, blockID BlockID) (*FeeEstimate, error) {
var raw FeeEstimate
if err := do(ctx, provider.c, "starknet_estimateMessageFee", &raw, msg, blockID); err != nil {
switch {
case errors.Is(err, ErrContractNotFound):
return nil, ErrContractNotFound
case errors.Is(err, ErrContractError):
return nil, ErrContractError
case errors.Is(err, ErrBlockNotFound):
return nil, ErrBlockNotFound
}
return nil, err
}
return &raw, nil
}
38 changes: 38 additions & 0 deletions rpc/contract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/starknet.go/utils"
"github.com/test-go/testify/require"
)

// TestClassAt tests code for a class.
Expand Down Expand Up @@ -261,3 +262,40 @@ func TestNonce(t *testing.T) {
}
}
}

// TestEstimateMessageFee tests EstimateMesssageFee
func TestEstimateMessageFee(t *testing.T) {
testConfig := beforeEach(t)

type testSetType struct {
MsgFromL1
BlockID
ExpectedFeeEst FeeEstimate
}
testSet := map[string][]testSetType{
"mock": {
{
MsgFromL1: MsgFromL1{FromAddress: &felt.Zero, ToAddress: &felt.Zero, Selector: &felt.Zero, Payload: []*felt.Felt{&felt.Zero}},
BlockID: BlockID{Tag: "latest"},
ExpectedFeeEst: FeeEstimate{
GasConsumed: NumAsHex("0x1"),
GasPrice: NumAsHex("0x2"),
OverallFee: NumAsHex("0x3"),
},
},
},
"testnet": {},
"mainnet": {},
}[testEnv]

for _, test := range testSet {
spy := NewSpy(testConfig.provider.c)
testConfig.provider.c = spy
value, err := testConfig.provider.EstimateMessageFee(context.Background(), test.MsgFromL1, test.BlockID)
if err != nil {
t.Fatal(err)
}
require.Equal(t, *value, test.ExpectedFeeEst)

}
}
32 changes: 32 additions & 0 deletions rpc/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func (r *rpcMock) CallContext(ctx context.Context, result interface{}, method st
return mock_starknet_addInvokeTransaction(result, method, args...)
case "starknet_estimateFee":
return mock_starknet_estimateFee(result, method, args...)
case "starknet_estimateMessageFee":
return mock_starknet_estimateMessageFee(result, method, args...)
default:
return errNotFound
}
Expand Down Expand Up @@ -437,6 +439,36 @@ func mock_starknet_estimateFee(result interface{}, method string, args ...interf
return nil
}

func mock_starknet_estimateMessageFee(result interface{}, method string, args ...interface{}) error {
r, ok := result.(*json.RawMessage)
if !ok {
return errWrongType
}
if len(args) != 2 {
fmt.Printf("args: %d\n", len(args))
return errWrongArgs
}
_, ok = args[0].(MsgFromL1)
if !ok {
fmt.Printf("args[0] should be MsgFromL1, got %T\n", args[0])
return errWrongArgs
}
_, ok = args[1].(BlockID)
if !ok {
fmt.Printf("args[1] should be *blockID, got %T\n", args[1])
return errWrongArgs
}

output := FeeEstimate{
GasConsumed: NumAsHex("0x1"),
GasPrice: NumAsHex("0x2"),
OverallFee: NumAsHex("0x3"),
}
outputContent, _ := json.Marshal(output)
json.Unmarshal(outputContent, r)
return nil
}

func mock_starknet_addInvokeTransaction(result interface{}, method string, args ...interface{}) error {
r, ok := result.(*json.RawMessage)
if !ok {
Expand Down
1 change: 1 addition & 0 deletions rpc/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type api interface {
ClassAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*ContractClass, error)
ClassHashAt(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*string, error)
EstimateFee(ctx context.Context, requests []BroadcastedTransaction, blockID BlockID) ([]FeeEstimate, error)
EstimateMessageFee(ctx context.Context, msg MsgFromL1, blockID BlockID) (*FeeEstimate, error)
Events(ctx context.Context, input EventsInput) (*EventsOutput, error)
Nonce(ctx context.Context, blockID BlockID, contractAddress *felt.Felt) (*string, error)
StateUpdate(ctx context.Context, blockID BlockID) (*StateUpdateOutput, error)
Expand Down
11 changes: 11 additions & 0 deletions rpc/types_transaction_receipt.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ type MsgToL1 struct {
Payload []*felt.Felt `json:"payload"`
}

type MsgFromL1 struct {
// FromAddress The address of the L1 contract sending the message
FromAddress *felt.Felt `json:"from_address"`
// ToAddress The target L2 address the message is sent to
ToAddress *felt.Felt `json:"to_address"`
// EntryPointSelector The selector of the l1_handler in invoke in the target contract
Selector *felt.Felt `json:"entry_point_selector"`
//Payload The payload of the message
Payload []*felt.Felt `json:"payload"`
}

type UnknownTransactionReceipt struct{ TransactionReceipt }

func (tr *UnknownTransactionReceipt) UnmarshalJSON(data []byte) error {
Expand Down

0 comments on commit 5a71ee9

Please sign in to comment.