Skip to content

Commit

Permalink
Switch to asyncpg
Browse files Browse the repository at this point in the history
  • Loading branch information
YoshihitoAso committed Jan 7, 2025
1 parent 4480c56 commit 1029107
Show file tree
Hide file tree
Showing 32 changed files with 685 additions and 578 deletions.
6 changes: 3 additions & 3 deletions app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, sessionmaker

from config import DATABASE_SCHEMA, DATABASE_URL, DB_ECHO
from config import ASYNC_DATABASE_URL, DATABASE_SCHEMA, DATABASE_URL, DB_ECHO


def get_engine(uri: str):
Expand Down Expand Up @@ -63,8 +63,8 @@ def get_batch_async_engine(uri: str):

# Create Engine
engine = get_engine(DATABASE_URL)
async_engine = get_async_engine(DATABASE_URL)
batch_async_engine = get_batch_async_engine(DATABASE_URL)
async_engine = get_async_engine(ASYNC_DATABASE_URL)
batch_async_engine = get_batch_async_engine(ASYNC_DATABASE_URL)

# Create Session Maker
SessionLocal = sessionmaker(autocommit=False, autoflush=True, bind=engine)
Expand Down
2 changes: 1 addition & 1 deletion app/model/db/batch_issue_redeem.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ class BatchIssueRedeem(Base):
# target account
account_address: Mapped[str] = mapped_column(String(42), nullable=False)
# amount
amount: Mapped[str] = mapped_column(BigInteger, nullable=False)
amount: Mapped[int] = mapped_column(BigInteger, nullable=False)
# processing status (pending:0, succeeded:1, failed:2)
status: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
23 changes: 14 additions & 9 deletions app/routers/issuer/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
from fastapi import APIRouter, Header, Path, Query, Request
from fastapi.exceptions import HTTPException
from sqlalchemy import and_, asc, desc, func, select
from sqlalchemy.exc import IntegrityError as SAIntegrityError, OperationalError
from sqlalchemy.exc import (
DBAPIError,
IntegrityError as SAIntegrityError,
OperationalError,
)

import config
from app.database import DBAsyncSession
Expand Down Expand Up @@ -516,7 +520,7 @@ async def generate_issuer_auth_token(
raise AuthTokenAlreadyExistsError()
# Update auth token
auth_token.auth_token = hashed_token
auth_token.usage_start = current_datetime_utc
auth_token.usage_start = current_datetime_utc.replace(tzinfo=None)
auth_token.valid_duration = data.valid_duration
await db.merge(auth_token)
await db.commit()
Expand All @@ -525,7 +529,7 @@ async def generate_issuer_auth_token(
auth_token = AuthToken()
auth_token.issuer_address = issuer_address
auth_token.auth_token = hashed_token
auth_token.usage_start = current_datetime_utc
auth_token.usage_start = current_datetime_utc.replace(tzinfo=None)
auth_token.valid_duration = data.valid_duration
db.add(auth_token)
await db.commit()
Expand Down Expand Up @@ -631,7 +635,7 @@ async def create_child_account(
.with_for_update(nowait=True)
)
).first()
except OperationalError:
except (OperationalError, DBAPIError):
await db.rollback()
await db.close()
raise ServiceUnavailableError(
Expand Down Expand Up @@ -728,7 +732,7 @@ async def create_child_account_in_batch(
.with_for_update(nowait=True)
)
).first()
except OperationalError:
except (OperationalError, DBAPIError):
await db.rollback()
await db.close()
raise ServiceUnavailableError(
Expand Down Expand Up @@ -833,30 +837,31 @@ async def list_all_child_account(
)
stmt = stmt.where(
IDXPersonalInfo.created
>= local_tz.localize(_created_from).astimezone(utc_tz)
>= local_tz.localize(_created_from).astimezone(utc_tz).replace(tzinfo=None)
)
if get_query.created_to:
_created_to = datetime.strptime(
get_query.created_to + ".999999", "%Y-%m-%d %H:%M:%S.%f"
)
stmt = stmt.where(
IDXPersonalInfo.created <= local_tz.localize(_created_to).astimezone(utc_tz)
IDXPersonalInfo.created
<= local_tz.localize(_created_to).astimezone(utc_tz).replace(tzinfo=None)
)
if get_query.modified_from:
_modified_from = datetime.strptime(
get_query.modified_from + ".000000", "%Y-%m-%d %H:%M:%S.%f"
)
stmt = stmt.where(
IDXPersonalInfo.modified
>= local_tz.localize(_modified_from).astimezone(utc_tz)
>= local_tz.localize(_modified_from).astimezone(utc_tz).replace(tzinfo=None)
)
if get_query.modified_to:
_modified_to = datetime.strptime(
get_query.modified_to + ".999999", "%Y-%m-%d %H:%M:%S.%f"
)
stmt = stmt.where(
IDXPersonalInfo.modified
<= local_tz.localize(_modified_to).astimezone(utc_tz)
<= local_tz.localize(_modified_to).astimezone(utc_tz).replace(tzinfo=None)
)

count = await db.scalar(select(func.count()).select_from(stmt.subquery()))
Expand Down
60 changes: 34 additions & 26 deletions app/routers/issuer/bond.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ async def issue_bond_token(
_update_token = UpdateToken()
_update_token.token_address = contract_address
_update_token.issuer_address = issuer_address
_update_token.type = TokenType.IBET_STRAIGHT_BOND.value
_update_token.type = TokenType.IBET_STRAIGHT_BOND
_update_token.arguments = token_dict
_update_token.status = 0 # pending
_update_token.trigger = "Issue"
Expand Down Expand Up @@ -340,7 +340,7 @@ async def issue_bond_token(

# Register token data
_token = Token()
_token.type = TokenType.IBET_STRAIGHT_BOND.value
_token.type = TokenType.IBET_STRAIGHT_BOND
_token.tx_hash = tx_hash
_token.issuer_address = issuer_address
_token.token_address = contract_address
Expand All @@ -353,10 +353,10 @@ async def issue_bond_token(
operation_log = TokenUpdateOperationLog()
operation_log.token_address = contract_address
operation_log.issuer_address = issuer_address
operation_log.type = TokenType.IBET_STRAIGHT_BOND.value
operation_log.type = TokenType.IBET_STRAIGHT_BOND
operation_log.arguments = token.model_dump()
operation_log.original_contents = None
operation_log.operation_category = TokenUpdateOperationCategory.ISSUE.value
operation_log.operation_category = TokenUpdateOperationCategory.ISSUE
db.add(operation_log)

await db.commit()
Expand Down Expand Up @@ -572,10 +572,10 @@ async def update_bond_token(
operation_log = TokenUpdateOperationLog()
operation_log.token_address = token_address
operation_log.issuer_address = issuer_address
operation_log.type = TokenType.IBET_STRAIGHT_BOND.value
operation_log.type = TokenType.IBET_STRAIGHT_BOND
operation_log.arguments = update_data.model_dump(exclude_none=True)
operation_log.original_contents = original_contents
operation_log.operation_category = TokenUpdateOperationCategory.UPDATE.value
operation_log.operation_category = TokenUpdateOperationCategory.UPDATE
db.add(operation_log)

await db.commit()
Expand Down Expand Up @@ -620,15 +620,15 @@ async def list_bond_operation_log_history(
)
stmt = stmt.where(
TokenUpdateOperationLog.created
>= local_tz.localize(_created_from).astimezone(utc_tz)
>= local_tz.localize(_created_from).astimezone(utc_tz).replace(tzinfo=None)
)
if request_query.created_to:
_created_to = datetime.strptime(
request_query.created_to + ".999999", "%Y-%m-%d %H:%M:%S.%f"
)
stmt = stmt.where(
TokenUpdateOperationLog.created
<= local_tz.localize(_created_to).astimezone(utc_tz)
<= local_tz.localize(_created_to).astimezone(utc_tz).replace(tzinfo=None)
)

count = await db.scalar(select(func.count()).select_from(stmt.subquery()))
Expand Down Expand Up @@ -972,15 +972,15 @@ async def issue_additional_bonds_in_batch(
raise InvalidParameterError("this token is temporarily unavailable")

# Generate upload_id
upload_id = uuid.uuid4()
upload_id = str(uuid.uuid4())

# Add batch data
_batch_upload = BatchIssueRedeemUpload()
_batch_upload.upload_id = upload_id
_batch_upload.issuer_address = issuer_address
_batch_upload.token_type = TokenType.IBET_STRAIGHT_BOND.value
_batch_upload.token_type = TokenType.IBET_STRAIGHT_BOND
_batch_upload.token_address = token_address
_batch_upload.category = BatchIssueRedeemProcessingCategory.ISSUE.value
_batch_upload.category = BatchIssueRedeemProcessingCategory.ISSUE
_batch_upload.processed = False
db.add(_batch_upload)

Expand Down Expand Up @@ -1385,15 +1385,15 @@ async def redeem_bonds_in_batch(
raise InvalidParameterError("this token is temporarily unavailable")

# Generate upload_id
upload_id = uuid.uuid4()
upload_id = str(uuid.uuid4())

# Add batch data
_batch_upload = BatchIssueRedeemUpload()
_batch_upload.upload_id = upload_id
_batch_upload.issuer_address = issuer_address
_batch_upload.token_type = TokenType.IBET_STRAIGHT_BOND.value
_batch_upload.token_type = TokenType.IBET_STRAIGHT_BOND
_batch_upload.token_address = token_address
_batch_upload.category = BatchIssueRedeemProcessingCategory.REDEEM.value
_batch_upload.category = BatchIssueRedeemProcessingCategory.REDEEM
_batch_upload.processed = False
db.add(_batch_upload)

Expand Down Expand Up @@ -1548,7 +1548,7 @@ async def list_all_scheduled_bond_token_update_events(
{
"scheduled_event_id": _token_event.event_id,
"token_address": token_address,
"token_type": TokenType.IBET_STRAIGHT_BOND.value,
"token_type": TokenType.IBET_STRAIGHT_BOND,
"scheduled_datetime": scheduled_datetime_utc.astimezone(
local_tz
).isoformat(),
Expand Down Expand Up @@ -1653,8 +1653,10 @@ async def schedule_bond_token_update_event(
_scheduled_event.event_id = str(uuid.uuid4())
_scheduled_event.issuer_address = issuer_address
_scheduled_event.token_address = token_address
_scheduled_event.token_type = TokenType.IBET_STRAIGHT_BOND.value
_scheduled_event.scheduled_datetime = event_data.scheduled_datetime
_scheduled_event.token_type = TokenType.IBET_STRAIGHT_BOND
_scheduled_event.scheduled_datetime = event_data.scheduled_datetime.astimezone(
UTC
).replace(tzinfo=None)
_scheduled_event.event_type = event_data.event_type
_scheduled_event.data = event_data.data.model_dump()
_scheduled_event.status = 0
Expand Down Expand Up @@ -1758,8 +1760,10 @@ async def schedule_bond_token_update_events_in_batch(
_scheduled_event.event_id = str(uuid.uuid4())
_scheduled_event.issuer_address = issuer_address
_scheduled_event.token_address = token_address
_scheduled_event.token_type = TokenType.IBET_STRAIGHT_BOND.value
_scheduled_event.scheduled_datetime = event_data.scheduled_datetime
_scheduled_event.token_type = TokenType.IBET_STRAIGHT_BOND
_scheduled_event.scheduled_datetime = event_data.scheduled_datetime.astimezone(
UTC
).replace(tzinfo=None)
_scheduled_event.event_type = event_data.event_type
_scheduled_event.data = event_data.data.model_dump()
_scheduled_event.status = 0
Expand Down Expand Up @@ -1827,7 +1831,7 @@ async def retrieve_scheduled_bond_token_update_event(
{
"scheduled_event_id": _token_event.event_id,
"token_address": token_address,
"token_type": TokenType.IBET_STRAIGHT_BOND.value,
"token_type": TokenType.IBET_STRAIGHT_BOND,
"scheduled_datetime": scheduled_datetime_utc.astimezone(
local_tz
).isoformat(),
Expand Down Expand Up @@ -1897,7 +1901,7 @@ async def delete_scheduled_bond_token_update_event(
rtn = {
"scheduled_event_id": _token_event.event_id,
"token_address": token_address,
"token_type": TokenType.IBET_STRAIGHT_BOND.value,
"token_type": TokenType.IBET_STRAIGHT_BOND,
"scheduled_datetime": scheduled_datetime_utc.astimezone(local_tz).isoformat(),
"event_type": _token_event.event_type,
"status": _token_event.status,
Expand Down Expand Up @@ -3206,12 +3210,16 @@ async def list_bond_token_transfer_history(
if query.block_timestamp_from is not None:
stmt = stmt.where(
IDXTransfer.block_timestamp
>= local_tz.localize(query.block_timestamp_from).astimezone(UTC)
>= local_tz.localize(query.block_timestamp_from)
.astimezone(UTC)
.replace(tzinfo=None)
)
if query.block_timestamp_to is not None:
stmt = stmt.where(
IDXTransfer.block_timestamp
<= local_tz.localize(query.block_timestamp_to).astimezone(UTC)
<= local_tz.localize(query.block_timestamp_to)
.astimezone(UTC)
.replace(tzinfo=None)
)
if query.from_address is not None:
stmt = stmt.where(IDXTransfer.from_address == query.from_address)
Expand Down Expand Up @@ -4276,13 +4284,13 @@ async def bulk_transfer_bond_token_ownership(
)

# Generate upload_id
upload_id = uuid.uuid4()
upload_id = str(uuid.uuid4())

# Add bulk transfer upload record
_bulk_transfer_upload = BulkTransferUpload()
_bulk_transfer_upload.upload_id = upload_id
_bulk_transfer_upload.issuer_address = issuer_address
_bulk_transfer_upload.token_type = TokenType.IBET_STRAIGHT_BOND.value
_bulk_transfer_upload.token_type = TokenType.IBET_STRAIGHT_BOND
_bulk_transfer_upload.token_address = token_address
_bulk_transfer_upload.status = 0
db.add(_bulk_transfer_upload)
Expand All @@ -4293,7 +4301,7 @@ async def bulk_transfer_bond_token_ownership(
_bulk_transfer.issuer_address = issuer_address
_bulk_transfer.upload_id = upload_id
_bulk_transfer.token_address = _transfer.token_address
_bulk_transfer.token_type = TokenType.IBET_STRAIGHT_BOND.value
_bulk_transfer.token_type = TokenType.IBET_STRAIGHT_BOND
_bulk_transfer.from_address = _transfer.from_address
_bulk_transfer.to_address = _transfer.to_address
_bulk_transfer.amount = _transfer.amount
Expand Down
12 changes: 6 additions & 6 deletions app/routers/issuer/settlement_issuer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,16 @@ async def list_all_dvp_deliveries(
if request_query.create_blocktimestamp_from is not None:
stmt = stmt.where(
IDXDelivery.create_blocktimestamp
>= local_tz.localize(request_query.create_blocktimestamp_from).astimezone(
tz=UTC
)
>= local_tz.localize(request_query.create_blocktimestamp_from)
.astimezone(tz=UTC)
.replace(tzinfo=None)
)
if request_query.create_blocktimestamp_to is not None:
stmt = stmt.where(
IDXDelivery.create_blocktimestamp
<= local_tz.localize(request_query.create_blocktimestamp_to).astimezone(
tz=UTC
)
<= local_tz.localize(request_query.create_blocktimestamp_to)
.astimezone(tz=UTC)
.replace(tzinfo=None)
)

count = await db.scalar(select(func.count()).select_from(stmt.subquery()))
Expand Down
Loading

0 comments on commit 1029107

Please sign in to comment.