diff --git a/api/engineapi/contracts_actions.py b/api/engineapi/contracts_actions.py index d95e71de..dfa1e6ec 100644 --- a/api/engineapi/contracts_actions.py +++ b/api/engineapi/contracts_actions.py @@ -164,7 +164,8 @@ def delete_registered_contract( def request_calls( db_session: Session, moonstream_user_id: uuid.UUID, - registered_contract_id: uuid.UUID, + registered_contract_id: Optional[uuid.UUID], + contract_address: Optional[str], call_specs: List[data.CallSpecification], ttl_days: Optional[int] = None, ) -> int: @@ -174,21 +175,31 @@ def request_calls( # TODO(zomglings): Do not pass raw ttl_days into SQL query - could be subject to SQL injection # For now, in the interest of speed, let us just be super cautious with ttl_days. # Check that the ttl_days is indeed an integer + if registered_contract_id is None and contract_address is None: + raise ValueError( + "At least one of registered_contract_id or contract_address is required" + ) + if ttl_days is not None: assert ttl_days == int(ttl_days), "ttl_days must be an integer" if ttl_days <= 0: raise ValueError("ttl_days must be positive") - # Check that the moonstream_user_id matches the RegisteredContract - try: - registered_contract = ( - db_session.query(RegisteredContract) - .filter( - RegisteredContract.id == registered_contract_id, - RegisteredContract.moonstream_user_id == moonstream_user_id, - ) - .one() + # Check that the moonstream_user_id matches a RegisteredContract with the given id or address + query = db_session.query(RegisteredContract).filter( + RegisteredContract.moonstream_user_id == moonstream_user_id + ) + + if registered_contract_id is not None: + query = query.filter(RegisteredContract.id == registered_contract_id) + + if contract_address is not None: + query = query.filter( + RegisteredContract.address == Web3.toChecksumAddress(contract_address) ) + + try: + registered_contract = query.one() except NoResultFound: raise ValueError("Invalid registered_contract_id or moonstream_user_id") diff --git a/api/engineapi/data.py b/api/engineapi/data.py index c68bd73e..f16a0c0d 100644 --- a/api/engineapi/data.py +++ b/api/engineapi/data.py @@ -214,6 +214,8 @@ class CallSpecification(BaseModel): class CreateCallRequestsAPIRequest(BaseModel): + contract_id: Optional[UUID] = None + contract_address: Optional[str] = None specifications: List[CallSpecification] = Field(default_factory=list) ttl_days: Optional[int] = None diff --git a/api/engineapi/routes/contracts.py b/api/engineapi/routes/contracts.py index 050740a6..ce67bc69 100644 --- a/api/engineapi/routes/contracts.py +++ b/api/engineapi/routes/contracts.py @@ -188,10 +188,9 @@ async def list_requests( return requests -@app.post("/{contract_id}/requests") +@app.post("/requests") async def create_requests( request: Request, - contract_id: UUID, data: data.CreateCallRequestsAPIRequest = Body(...), db_session: Session = Depends(db.yield_db_session), ) -> int: @@ -202,7 +201,8 @@ async def create_requests( num_requests = contracts_actions.request_calls( db_session=db_session, moonstream_user_id=request.state.user.id, - registered_contract_id=contract_id, + registered_contract_id=data.contract_id, + contract_address=data.contract_address, call_specs=data.specifications, ttl_days=data.ttl_days, )