Skip to content

Commit

Permalink
Merge pull request #591 from BoostryJP/feature/#527
Browse files Browse the repository at this point in the history
Replace sync func to async func
  • Loading branch information
YoshihitoAso authored Jan 29, 2024
2 parents e931015 + 29c377f commit 044c185
Show file tree
Hide file tree
Showing 354 changed files with 11,896 additions and 7,334 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 23.11.0
rev: 24.1.0
hooks:
- id: black
language_version: python3.11
78 changes: 65 additions & 13 deletions app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,70 @@
SPDX-License-Identifier: Apache-2.0
"""

from typing import Annotated

from fastapi import Depends
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, sessionmaker

from config import DATABASE_SCHEMA, DATABASE_URL, DB_ECHO

options = {
"pool_recycle": 3600,
"pool_size": 10,
"pool_timeout": 30,
"pool_pre_ping": True,
"max_overflow": 30,
"echo": DB_ECHO,
}
engine = create_engine(DATABASE_URL, **options)

def get_engine(uri: str):
options = {
"pool_recycle": 3600,
"pool_size": 10,
"pool_timeout": 30,
"pool_pre_ping": True,
"max_overflow": 30,
"echo": DB_ECHO,
}
return create_engine(uri, **options)


def get_async_engine(uri: str):
options = {
"pool_recycle": 3600,
"pool_size": 10,
"pool_timeout": 30,
"pool_pre_ping": True,
"max_overflow": 30,
"echo": DB_ECHO,
}
return create_async_engine(uri, **options)


def get_batch_async_engine(uri: str):
options = {
"pool_pre_ping": True,
"echo": False,
}
return create_async_engine(uri, **options)


# Create Engine
engine = get_engine(DATABASE_URL)
async_engine = get_async_engine(DATABASE_URL)
batch_async_engine = get_batch_async_engine(DATABASE_URL)

# Create Session Maker
SessionLocal = sessionmaker(autocommit=False, autoflush=True, bind=engine)
AsyncSessionLocal = async_sessionmaker(
autocommit=False,
autoflush=True,
expire_on_commit=False,
bind=async_engine,
class_=AsyncSession,
)
BatchAsyncSessionLocal = async_sessionmaker(
autocommit=False,
autoflush=True,
expire_on_commit=False,
bind=batch_async_engine,
class_=AsyncSession,
)


def db_session():
Expand All @@ -44,11 +90,17 @@ def db_session():
db.close()


async def db_async_session():
db = AsyncSessionLocal()
try:
yield db
finally:
await db.close()


DBSession = Annotated[Session, Depends(db_session)]
DBAsyncSession = Annotated[AsyncSession, Depends(db_async_session)]


def get_db_schema():
if DATABASE_SCHEMA and engine.name != "mysql":
return DATABASE_SCHEMA
else:
return None
return DATABASE_SCHEMA
1 change: 1 addition & 0 deletions app/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SPDX-License-Identifier: Apache-2.0
"""

from fastapi import status

from app.utils.contract_error_code import REVERT_CODE_MAP, error_code_msg
Expand Down
7 changes: 6 additions & 1 deletion app/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SPDX-License-Identifier: Apache-2.0
"""

import logging
import sys
import urllib
Expand Down Expand Up @@ -128,7 +129,11 @@ def output_access_log(req: Request, res: Response, request_start_time: datetime)


def __auth_format(req: Request, address: str, msg: str):
return AUTH_FORMAT % (req.client.host, address, msg)
if req.client is None:
_host = ""
else:
_host = req.client.host
return AUTH_FORMAT % (_host, address, msg)


def __get_url(req: Request):
Expand Down
1 change: 1 addition & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SPDX-License-Identifier: Apache-2.0
"""

from datetime import datetime

from fastapi import FastAPI, Request
Expand Down
1 change: 1 addition & 0 deletions app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SPDX-License-Identifier: Apache-2.0
"""

from typing import Annotated, Any

from pydantic import WrapValidator
Expand Down
1 change: 1 addition & 0 deletions app/model/blockchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SPDX-License-Identifier: Apache-2.0
"""

from .e2e_messaging import E2EMessaging
from .exchange import IbetExchangeInterface, IbetSecurityTokenEscrow
from .freeze_log import FreezeLogContract
Expand Down
35 changes: 21 additions & 14 deletions app/model/blockchain/e2e_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SPDX-License-Identifier: Apache-2.0
"""

import base64
import json
import secrets
Expand All @@ -27,7 +28,7 @@
from web3.exceptions import TimeExhausted

from app.exceptions import ContractRevertError, SendTransactionError
from app.utils.contract_utils import ContractUtils
from app.utils.contract_utils import AsyncContractUtils
from config import (
AWS_KMS_GENERATE_RANDOM_ENABLED,
AWS_REGION_NAME,
Expand All @@ -42,23 +43,27 @@ class E2EMessaging:
def __init__(self, contract_address: str):
self.contract_address = contract_address

def send_message(
self, to_address: str, message: str, tx_from: str, private_key: str
async def send_message(
self, to_address: str, message: str, tx_from: str, private_key: bytes
):
"""Send Message"""
contract = ContractUtils.get_contract(
contract = AsyncContractUtils.get_contract(
contract_name="E2EMessaging", contract_address=self.contract_address
)
try:
tx = contract.functions.sendMessage(to_address, message).build_transaction(
tx = await contract.functions.sendMessage(
to_address, message
).build_transaction(
{
"chainId": CHAIN_ID,
"from": tx_from,
"gas": TX_GAS_LIMIT,
"gasPrice": 0,
}
)
tx_hash, tx_receipt = ContractUtils.send_transaction(tx, private_key)
tx_hash, tx_receipt = await AsyncContractUtils.send_transaction(
tx, private_key
)
return tx_hash, tx_receipt
except ContractRevertError:
raise
Expand All @@ -67,14 +72,14 @@ def send_message(
except Exception as err:
raise SendTransactionError(err)

def send_message_external(
async def send_message_external(
self,
to_address: str,
_type: str,
message_org: str,
to_rsa_public_key: str,
tx_from: str,
private_key: str,
private_key: bytes,
):
"""Send Message(Format message for external system)"""

Expand Down Expand Up @@ -110,23 +115,23 @@ def send_message_external(
message = json.dumps(message_dict)

# Send message
tx_hash, tx_receipt = E2EMessaging(self.contract_address).send_message(
tx_hash, tx_receipt = await E2EMessaging(self.contract_address).send_message(
to_address=to_address,
message=message,
tx_from=tx_from,
private_key=private_key,
)
return tx_hash, tx_receipt

def set_public_key(
self, public_key: str, key_type: str, tx_from: str, private_key: str
async def set_public_key(
self, public_key: str, key_type: str, tx_from: str, private_key: bytes
):
"""Set Public Key"""
contract = ContractUtils.get_contract(
contract = AsyncContractUtils.get_contract(
contract_name="E2EMessaging", contract_address=self.contract_address
)
try:
tx = contract.functions.setPublicKey(
tx = await contract.functions.setPublicKey(
public_key, key_type
).build_transaction(
{
Expand All @@ -136,7 +141,9 @@ def set_public_key(
"gasPrice": 0,
}
)
tx_hash, tx_receipt = ContractUtils.send_transaction(tx, private_key)
tx_hash, tx_receipt = await AsyncContractUtils.send_transaction(
tx, private_key
)
return tx_hash, tx_receipt
except ContractRevertError:
raise
Expand Down
19 changes: 10 additions & 9 deletions app/model/blockchain/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
SPDX-License-Identifier: Apache-2.0
"""

from web3.exceptions import TimeExhausted

from app.exceptions import ContractRevertError, SendTransactionError
from app.model.blockchain.tx_params.ibet_security_token_escrow import (
ApproveTransferParams,
)
from app.utils.contract_utils import ContractUtils
from app.utils.contract_utils import AsyncContractUtils
from config import CHAIN_ID, TX_GAS_LIMIT


Expand All @@ -32,18 +33,18 @@ class IbetExchangeInterface:
def __init__(
self, contract_address: str, contract_name: str = "IbetExchangeInterface"
):
self.exchange_contract = ContractUtils.get_contract(
self.exchange_contract = AsyncContractUtils.get_contract(
contract_name=contract_name, contract_address=contract_address
)

def get_account_balance(self, account_address: str, token_address: str):
async def get_account_balance(self, account_address: str, token_address: str):
"""Get account balance
:param account_address: account address
:param token_address: token address
:return: account balance
"""
balance = ContractUtils.call_function(
balance = await AsyncContractUtils.call_function(
contract=self.exchange_contract,
function_name="balanceOf",
args=(
Expand All @@ -52,7 +53,7 @@ def get_account_balance(self, account_address: str, token_address: str):
),
default_returns=0,
)
commitment = ContractUtils.call_function(
commitment = await AsyncContractUtils.call_function(
contract=self.exchange_contract,
function_name="commitmentOf",
args=(
Expand All @@ -73,12 +74,12 @@ def __init__(self, contract_address: str):
contract_address=contract_address, contract_name="IbetSecurityTokenEscrow"
)

def approve_transfer(
self, data: ApproveTransferParams, tx_from: str, private_key: str
async def approve_transfer(
self, data: ApproveTransferParams, tx_from: str, private_key: bytes
):
"""Approve Transfer"""
try:
tx = self.exchange_contract.functions.approveTransfer(
tx = await self.exchange_contract.functions.approveTransfer(
data.escrow_id, data.data
).build_transaction(
{
Expand All @@ -88,7 +89,7 @@ def approve_transfer(
"gasPrice": 0,
}
)
tx_hash, tx_receipt = ContractUtils.send_transaction(
tx_hash, tx_receipt = await AsyncContractUtils.send_transaction(
transaction=tx, private_key=private_key
)
return tx_hash, tx_receipt
Expand Down
Loading

0 comments on commit 044c185

Please sign in to comment.