Skip to content
This repository has been archived by the owner on Dec 15, 2023. It is now read-only.

Fix function calling #426

Merged
merged 2 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions starknet_devnet/blueprints/feeder_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Feeder gateway routes.
"""

from typing import Type

from flask import Blueprint, Response, jsonify, request
from marshmallow import ValidationError
from starkware.starknet.services.api.feeder_gateway.request_objects import (
Expand All @@ -21,6 +23,7 @@
Transaction,
)
from starkware.starkware_utils.error_handling import StarkErrorCode
from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass
from werkzeug.datastructures import MultiDict

from starknet_devnet.blueprints.rpc.structures.types import BlockId
Expand All @@ -34,7 +37,7 @@
feeder_gateway = Blueprint("feeder_gateway", __name__, url_prefix="/feeder_gateway")


def validate_request(data: bytes, cls, many=False):
def validate_request(data: bytes, cls: Type[ValidatedMarshmallowDataclass], many=False):
"""Ensure `data` is valid Starknet function call. Returns an object of type specified with `cls`."""
try:
return cls.Schema().loads(data, many=many)
Expand Down Expand Up @@ -149,13 +152,14 @@ async def call_contract():
"""

block_id = _get_block_id(request.args)
data = request.get_data() # better than request.data in some edge cases

try:
# version 1
call_specifications = validate_request(request.data, CallFunction)
call_specifications = validate_request(data, CallFunction)
except StarknetDevnetException:
# version 0
call_specifications = validate_request(request.data, InvokeFunction)
call_specifications = validate_request(data, InvokeFunction)

result_dict = await state.starknet_wrapper.call(call_specifications, block_id)
return jsonify(result_dict)
Expand Down Expand Up @@ -330,11 +334,12 @@ async def get_state_update():
@feeder_gateway.route("/estimate_fee", methods=["POST"])
async def estimate_fee():
"""Returns the estimated fee for a transaction."""
data = request.get_data()

try:
transaction = validate_request(request.data, Transaction) # version 1
transaction = validate_request(data, Transaction) # version 1
except StarknetDevnetException:
transaction = validate_request(request.data, InvokeFunction) # version 0
transaction = validate_request(data, InvokeFunction) # version 0

block_id = _get_block_id(request.args)
skip_validate = _get_skip_validate(request.args)
Expand All @@ -351,10 +356,10 @@ async def estimate_fee_bulk():

try:
# version 1
transactions = validate_request(request.data, Transaction, many=True)
transactions = validate_request(request.get_data(), Transaction, many=True)
except StarknetDevnetException:
# version 0
transactions = validate_request(request.data, InvokeFunction, many=True)
transactions = validate_request(request.get_data(), InvokeFunction, many=True)

block_id = _get_block_id(request.args)
skip_validate = _get_skip_validate(request.args)
Expand All @@ -370,7 +375,7 @@ async def estimate_fee_bulk():
@feeder_gateway.route("/simulate_transaction", methods=["POST"])
async def simulate_transaction():
"""Returns the estimated fee for a transaction."""
transaction = validate_request(request.data, AccountTransaction)
transaction = validate_request(request.get_data(), AccountTransaction)
block_id = _get_block_id(request.args)
skip_validate = _get_skip_validate(request.args)

Expand Down Expand Up @@ -404,6 +409,6 @@ async def estimate_message_fee():

block_id = _get_block_id(request.args)

call = validate_request(request.data, CallL1Handler)
call = validate_request(request.get_data(), CallL1Handler)
fee_estimation = await state.starknet_wrapper.estimate_message_fee(call, block_id)
return jsonify(fee_estimation)
18 changes: 14 additions & 4 deletions starknet_devnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
"""

import asyncio
import json
import os
import sys

from flask import Flask, jsonify
from flask_cors import CORS
from gunicorn.app.base import BaseApplication
from starkware.starkware_utils.error_handling import StarkException
from starkware.starkware_utils.error_handling import StarkErrorCode, StarkException

from .blueprints.base import base
from .blueprints.feeder_gateway import feeder_gateway
Expand Down Expand Up @@ -94,9 +95,9 @@ def main():

asyncio.run(state.starknet_wrapper.initialize())

main_pid = os.getpid()
print(f" * Listening on http://{args.host}:{args.port}/ (Press CTRL+C to quit)")
try:
print(f" * Listening on http://{args.host}:{args.port}/ (Press CTRL+C to quit)")
main_pid = os.getpid()
GunicornServer(app, args).run()
except KeyboardInterrupt:
pass
Expand All @@ -108,14 +109,23 @@ def main():


@app.errorhandler(StarkException)
def handle(error: StarkException):
def handle_stark_exception(error: StarkException):
"""Handles the error and responds in JSON."""
return {
"message": error.message,
"code": str(error.code),
}, error.status_code


@app.errorhandler(json.decoder.JSONDecodeError)
def handle_json_decode_error(error: json.decoder.JSONDecodeError):
"""Handles json error"""
return {
"message": f"Error while decoding JSON: {error}",
"code": str(StarkErrorCode.MALFORMED_REQUEST),
}, 500


@app.route("/api", methods=["GET"])
def api():
"""Return available endpoints."""
Expand Down
13 changes: 11 additions & 2 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,13 +633,22 @@ async def __get_query_state(self, block_id: BlockId = DEFAULT_BLOCK_ID):
)

async def call(
self, transaction: CallFunction, block_id: BlockId = DEFAULT_BLOCK_ID
self,
transaction: Union[CallFunction, InvokeFunction],
block_id: BlockId = DEFAULT_BLOCK_ID,
):
"""Perform call according to specifications in `transaction`."""
state = await self.__get_query_state(block_id)

# property name different since starknet 0.11
address = (
transaction.contract_address
if isinstance(transaction, CallFunction)
else transaction.sender_address
)

call_info = await state.copy().execute_entry_point_raw(
contract_address=transaction.contract_address,
contract_address=address,
selector=transaction.entry_point_selector,
calldata=transaction.calldata,
caller_address=0,
Expand Down
31 changes: 31 additions & 0 deletions test/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import requests
from starkware.starknet.definitions.error_codes import StarknetErrorCode
from starkware.starkware_utils.error_handling import StarkErrorCode

from starknet_devnet.constants import DEFAULT_GAS_PRICE
from starknet_devnet.server import app
Expand Down Expand Up @@ -430,3 +431,33 @@ def test_get_transaction_receipt_with_tx_hash_0():
resp = get_transaction_receipt_test_client("0")
assert resp.json["message"].startswith(INVALID_TRANSACTION_HASH_MESSAGE_PREFIX)
assert resp.status_code == 500


@pytest.mark.parametrize("address_property", ["contract_address", "sender_address"])
def test_calling_function_with_different_address_properties(address_property: str):
"""In starknet 0.11 contract_address was changed to sender_address"""
dummy_uninitialized_address = "0x01"
resp = app.test_client().post(
"/feeder_gateway/call_contract",
content_type="a",
data=json.dumps(
{
"entry_point_selector": "0x0",
"calldata": [],
"signature": [],
address_property: dummy_uninitialized_address,
}
),
)

assert resp.status_code == 500
assert resp.is_json
assert resp.json.get("code") == str(StarknetErrorCode.UNINITIALIZED_CONTRACT)


def test_calling_without_body():
"""Test graceful failing without body"""
resp = app.test_client().post("/feeder_gateway/call_contract")
assert resp.status_code == 500
assert resp.is_json
assert resp.json.get("code") == str(StarkErrorCode.MALFORMED_REQUEST)