Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple Dropper v0.2.0 API improvements #299

Merged
merged 12 commits into from
May 9, 2023
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Fix unique constract on registered_contracts to include moonstream_user_id

Revision ID: dedd8a7d0624
Revises: d1be5f227664
Create Date: 2023-05-02 15:52:36.654980

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "dedd8a7d0624"
down_revision = "d1be5f227664"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(
"uq_registered_contracts_blockchain", "registered_contracts", type_="unique"
)
op.create_unique_constraint(
op.f("uq_registered_contracts_blockchain"),
"registered_contracts",
["blockchain", "moonstream_user_id", "address", "contract_type"],
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(
op.f("uq_registered_contracts_blockchain"),
"registered_contracts",
type_="unique",
)
op.create_unique_constraint(
"uq_registered_contracts_blockchain",
"registered_contracts",
["blockchain", "address", "contract_type"],
)
# ### end Alembic commands ###
138 changes: 103 additions & 35 deletions api/engineapi/contracts_actions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import argparse
from datetime import timedelta
import json
import logging
import uuid
from enum import Enum
from typing import Any, Dict, List, Optional

from sqlalchemy import func, text
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import Session
from web3 import Web3

from .data import ContractType

from . import data, db
from .models import RegisteredContract, CallRequest

Expand All @@ -21,11 +23,6 @@ class ContractAlreadyRegistered(Exception):
pass


class ContractType(Enum):
raw = "raw"
dropper = "dropper-v0.2.0"


def validate_method_and_params(
contract_type: ContractType, method: str, parameters: Dict[str, Any]
) -> None:
Expand Down Expand Up @@ -71,12 +68,6 @@ def register_contract(
"""
Register a contract against the Engine instance
"""

# TODO(zomglings): Make it so that contract_type is passed as a string. Convert to
# ContractType here. That will mean there is a single point at which the validation is
# performed rather than relying on each entrypoint to register_contract having to implement
# their own validation.

try:
contract = RegisteredContract(
blockchain=blockchain,
Expand All @@ -100,6 +91,46 @@ def register_contract(
return render_registered_contract(contract)


def update_registered_contract(
db_session: Session,
moonstream_user_id: uuid.UUID,
contract_id: uuid.UUID,
title: Optional[str] = None,
description: Optional[str] = None,
image_uri: Optional[str] = None,
ignore_nulls: bool = True,
) -> data.RegisteredContract:
"""
Update the registered contract with the given contract ID provided that the user with moonstream_user_id
has access to it.
"""
query = db_session.query(RegisteredContract).filter(
RegisteredContract.id == contract_id,
RegisteredContract.moonstream_user_id == moonstream_user_id,
)

contract = query.one()

if not (title is None and ignore_nulls):
contract.title = title
if not (description is None and ignore_nulls):
contract.description = description
if not (image_uri is None and ignore_nulls):
contract.image_uri = image_uri

try:
db_session.add(contract)
db_session.commit()
except Exception as err:
logger.error(
f"update_registered_contract -- error storing update in database: {repr(err)}"
)
db_session.rollback()
raise

return render_registered_contract(contract)


def lookup_registered_contracts(
db_session: Session,
moonstream_user_id: uuid.UUID,
Expand Down Expand Up @@ -164,7 +195,8 @@ def delete_registered_contract(
def request_calls(
db_session: Session,
moonstream_user_id: uuid.UUID,
registered_contract_id: uuid.UUID,
registered_contract_id: Optional[uuid.UUID],
contract_address: Optional[str],
call_specs: List[data.CallSpecification],
ttl_days: Optional[int] = None,
) -> int:
Expand All @@ -174,21 +206,31 @@ def request_calls(
# TODO(zomglings): Do not pass raw ttl_days into SQL query - could be subject to SQL injection
# For now, in the interest of speed, let us just be super cautious with ttl_days.
# Check that the ttl_days is indeed an integer
if registered_contract_id is None and contract_address is None:
raise ValueError(
"At least one of registered_contract_id or contract_address is required"
)

if ttl_days is not None:
assert ttl_days == int(ttl_days), "ttl_days must be an integer"
if ttl_days <= 0:
raise ValueError("ttl_days must be positive")

# Check that the moonstream_user_id matches the RegisteredContract
try:
registered_contract = (
db_session.query(RegisteredContract)
.filter(
RegisteredContract.id == registered_contract_id,
RegisteredContract.moonstream_user_id == moonstream_user_id,
)
.one()
# Check that the moonstream_user_id matches a RegisteredContract with the given id or address
query = db_session.query(RegisteredContract).filter(
RegisteredContract.moonstream_user_id == moonstream_user_id
)

if registered_contract_id is not None:
query = query.filter(RegisteredContract.id == registered_contract_id)

if contract_address is not None:
query = query.filter(
RegisteredContract.address == Web3.toChecksumAddress(contract_address)
)

try:
registered_contract = query.one()
except NoResultFound:
raise ValueError("Invalid registered_contract_id or moonstream_user_id")

Expand All @@ -202,18 +244,17 @@ def request_calls(
contract_type, specification.method, specification.parameters
)

# Calculate the expiration time (if ttl_days is specified)
expires_at_sql = None
expires_at = None
if ttl_days is not None:
expires_at_sql = text(f"(NOW() + INTERVAL '{ttl_days} days')")
expires_at = func.now() + timedelta(days=ttl_days)

request = CallRequest(
registered_contract_id=registered_contract.id,
caller=normalized_caller,
moonstream_user_id=moonstream_user_id,
method=specification.method,
parameters=specification.parameters,
expires_at=expires_at_sql,
expires_at=expires_at,
)

db_session.add(request)
Expand All @@ -229,21 +270,42 @@ def request_calls(

def list_call_requests(
db_session: Session,
registered_contract_id: uuid.UUID,
caller: str,
contract_id: Optional[uuid.UUID],
contract_address: Optional[str],
caller: Optional[str],
limit: int = 10,
offset: Optional[int] = None,
show_expired: bool = False,
) -> List[data.CallRequest]:
"""
List call requests for the given moonstream_user_id
"""
if caller is None:
raise ValueError("caller must be specified")

if contract_id is None and contract_address is None:
raise ValueError(
"At least one of contract_id or contract_address must be specified"
)

# If show_expired is False, filter out expired requests using current time on database server
query = db_session.query(CallRequest).filter(
CallRequest.registered_contract_id == registered_contract_id,
CallRequest.caller == Web3.toChecksumAddress(caller),
query = (
db_session.query(CallRequest, RegisteredContract)
.join(
RegisteredContract,
CallRequest.registered_contract_id == RegisteredContract.id,
)
.filter(CallRequest.caller == Web3.toChecksumAddress(caller))
)

if contract_id is not None:
query = query.filter(CallRequest.registered_contract_id == contract_id)

if contract_address is not None:
query = query.filter(
RegisteredContract.address == Web3.toChecksumAddress(contract_address)
)

if not show_expired:
query = query.filter(
CallRequest.expires_at > func.now(),
Expand All @@ -254,7 +316,10 @@ def list_call_requests(

query = query.limit(limit)
results = query.all()
return [render_call_request(call_request) for call_request in results]
return [
render_call_request(call_request, registered_contract)
for call_request, registered_contract in results
]


# TODO(zomglings): What should the delete functionality for call requests look like?
Expand Down Expand Up @@ -282,10 +347,13 @@ def render_registered_contract(contract: RegisteredContract) -> data.RegisteredC
)


def render_call_request(call_request: CallRequest) -> data.CallRequest:
def render_call_request(
call_request: CallRequest, registered_contract: RegisteredContract
) -> data.CallRequest:
return data.CallRequest(
id=call_request.id,
registered_contract_id=call_request.registered_contract_id,
contract_id=call_request.registered_contract_id,
contract_address=registered_contract.address,
moonstream_user_id=call_request.moonstream_user_id,
caller=call_request.caller,
method=call_request.method,
Expand Down Expand Up @@ -404,7 +472,7 @@ def handle_list_requests(args: argparse.Namespace) -> None:
with db.yield_db_session_ctx() as db_session:
call_requests = list_call_requests(
db_session=db_session,
registered_contract_id=args.registered_contract_id,
contract_id=args.registered_contract_id,
caller=args.caller,
limit=args.limit,
offset=args.offset,
Expand Down
46 changes: 38 additions & 8 deletions api/engineapi/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, validator
from uuid import UUID

from pydantic import BaseModel, Field, validator, root_validator
from web3 import Web3


class PingResponse(BaseModel):
"""
Expand Down Expand Up @@ -53,7 +55,6 @@ class DropperBlockchainResponse(BaseModel):


class DropRegisterRequest(BaseModel):

dropper_contract_id: UUID
title: Optional[str] = None
description: Optional[str] = None
Expand Down Expand Up @@ -177,13 +178,25 @@ class DropUpdatedResponse(BaseModel):
active: bool = True


class ContractType(Enum):
raw = "raw"
dropper = "dropper-v0.2.0"


class RegisterContractRequest(BaseModel):
blockchain: str
address: str
contract_type: str
contract_type: ContractType
title: Optional[str] = None
description: Optional[str] = None
image_uri: Optional[str] = None


class UpdateContractRequest(BaseModel):
title: Optional[str] = None
description: Optional[str] = None
image_uri: Optional[str] = None
ignore_nulls: bool = True


class RegisteredContract(BaseModel):
Expand Down Expand Up @@ -214,28 +227,45 @@ class CallSpecification(BaseModel):


class CreateCallRequestsAPIRequest(BaseModel):
contract_id: Optional[UUID] = None
contract_address: Optional[str] = None
specifications: List[CallSpecification] = Field(default_factory=list)
ttl_days: Optional[int] = None

# Solution found thanks to https://github.com/pydantic/pydantic/issues/506
@root_validator
def at_least_one_of_contract_id_and_contract_address(cls, values):
if values.get("contract_id") is None and values.get("contract_address") is None:
raise ValueError(
"At least one of contract_id and contract_address must be provided"
)
return values


class CallRequest(BaseModel):
id: UUID
registered_contract_id: UUID
contract_id: UUID
contract_address: str
moonstream_user_id: UUID
caller: str
method: str
parameters: Dict[str, Any]
expires_at: datetime
expires_at: Optional[datetime]
created_at: datetime
updated_at: datetime

@validator("id", "registered_contract_id", "moonstream_user_id")
@validator("id", "contract_id", "moonstream_user_id")
def validate_uuids(cls, v):
return str(v)

@validator("created_at", "updated_at", "expires_at")
def validate_datetimes(cls, v):
return v.isoformat()
if v is not None:
return v.isoformat()

@validator("contract_address", "caller")
def validate_web3_adresses(cls, v):
return Web3.toChecksumAddress(v)


class QuartilesResponse(BaseModel):
Expand Down
Loading