Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests for oracle sources #212

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/driftpy/accounts/bulk_account_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def add_account(
if existing_size == 0:
self._start_loading()

# If the account is already loaded, call the callback immediately
if existing_account_to_load is not None:
buffer_and_slot = self.buffer_and_slot_map.get(pubkey_str)
if buffer_and_slot is not None and buffer_and_slot.buffer is not None:
self.handle_callbacks(existing_account_to_load, buffer_and_slot.buffer, buffer_and_slot.slot)

return callback_id

def get_callback_id(self) -> int:
Expand Down
67 changes: 37 additions & 30 deletions src/driftpy/accounts/polling/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_state_public_key,
)
from driftpy.constants.config import find_all_market_and_oracles
from driftpy.oracles.oracle_id import get_oracle_id
from driftpy.types import (
OracleInfo,
OraclePriceData,
Expand All @@ -38,6 +39,7 @@ def __init__(
self.program = program
self.is_subscribed = False
self.callbacks: dict[str, int] = {}
self.oracle_callbacks: dict[str, int] = {}

self.perp_market_indexes = perp_market_indexes
self.spot_market_indexes = spot_market_indexes
Expand All @@ -49,7 +51,9 @@ def __init__(
self.spot_markets = {}
self.oracle = {}
self.perp_oracle_map: dict[int, Pubkey] = {}
self.perp_oracle_strings_map: dict[int, str] = {}
self.spot_oracle_map: dict[int, Pubkey] = {}
self.spot_oracle_strings_map: dict[int, str] = {}

async def subscribe(self):
if len(self.callbacks) != 0:
Expand Down Expand Up @@ -138,41 +142,38 @@ def cb(buffer: bytes, slot: int):
return cb

async def add_oracle(self, oracle: Pubkey, oracle_source: OracleSource):
if oracle == Pubkey.default() or oracle in self.oracle:
return True

oracle_str = str(oracle)
if oracle_str in self.callbacks:
oracle_id = get_oracle_id(oracle, oracle_source)
if oracle == Pubkey.default() or oracle_id in self.oracle:
return True

callback_id = self.bulk_account_loader.add_account(
oracle, self._get_oracle_callback(oracle_str, oracle_source)
oracle, self._get_oracle_callback(oracle_id, oracle_source)
)
self.callbacks[oracle_str] = callback_id
self.oracle_callbacks[oracle_id] = callback_id

await self._wait_for_oracle(3, oracle_str)
await self._wait_for_oracle(3, oracle_id)

return True

async def _wait_for_oracle(self, tries: int, oracle: str):
async def _wait_for_oracle(self, tries: int, oracle_id: str):
while tries > 0:
await asyncio.sleep(self.bulk_account_loader.frequency)
if oracle in self.bulk_account_loader.buffer_and_slot_map:
if oracle_id in self.oracle:
return
tries -= 1
print(
f"WARNING: Oracle: {oracle} not found after {tries * self.bulk_account_loader.frequency} seconds, Location: {stack_trace()}"
f"WARNING: Oracle: {oracle_id} not found after {tries * self.bulk_account_loader.frequency} seconds, Location: {stack_trace()}"
)

def _get_oracle_callback(self, oracle_str: str, oracle_source: OracleSource):
def _get_oracle_callback(self, oracle_id: str, oracle_source: OracleSource):
decode = get_oracle_decode_fn(oracle_source)

def cb(buffer: bytes, slot: int):
if buffer is None:
return

decoded_data = decode(buffer)
self.oracle[oracle_str] = DataAndSlot(slot, decoded_data)
self.oracle[oracle_id] = DataAndSlot(slot, decoded_data)

return cb

Expand All @@ -181,7 +182,14 @@ async def unsubscribe(self):
self.bulk_account_loader.remove_account(
Pubkey.from_string(pubkey_str), callback_id
)

for oracle_id, callback_id in self.oracle_callbacks.items():
self.bulk_account_loader.remove_account(
Pubkey.from_string(oracle_id.split("-")[0]), callback_id
)

self.callbacks.clear()
self.oracle_callbacks.clear()

def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
return self.state
Expand All @@ -197,9 +205,9 @@ def get_spot_market_and_slot(
return self.spot_markets.get(market_index)

def get_oracle_price_data_and_slot(
self, oracle: Pubkey
self, oracle_id: str
) -> Optional[DataAndSlot[OraclePriceData]]:
return self.oracle.get(str(oracle))
return self.oracle.get(oracle_id)

def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]:
return [
Expand All @@ -221,53 +229,52 @@ async def _set_perp_oracle_map(self):
market_account = market.data
market_index = market_account.market_index
oracle = market_account.amm.oracle
if oracle not in self.oracle:
await self.add_oracle(oracle, market_account.amm.oracle_source)
oracle_source = market_account.amm.oracle_source
oracle_id = get_oracle_id(oracle, oracle_source)
if oracle_id not in self.oracle:
await self.add_oracle(oracle, oracle_source)
self.perp_oracle_map[market_index] = oracle
self.perp_oracle_strings_map[market_index] = oracle_id

async def _set_spot_oracle_map(self):
spot_markets = self.get_spot_market_accounts_and_slots()
for market in spot_markets:
market_account = market.data
market_index = market_account.market_index
oracle = market_account.oracle
if oracle not in self.oracle:
await self.add_oracle(oracle, market_account.oracle_source)
oracle_source = market_account.oracle_source
oracle_id = get_oracle_id(oracle, oracle_source)
if oracle_id not in self.oracle:
await self.add_oracle(oracle, oracle_source)
self.spot_oracle_map[market_index] = oracle
self.spot_oracle_strings_map[market_index] = oracle_id

def get_oracle_price_data_and_slot_for_perp_market(
self, market_index: int
) -> Union[DataAndSlot[OraclePriceData], None]:
print(
"==> PollingDriftClientAccountSubscriber: Getting oracle price data for perp market",
market_index,
)
print(self.perp_markets)
print(self.spot_markets)
perp_market_account = self.get_perp_market_and_slot(market_index)
oracle = self.perp_oracle_map.get(market_index)

print("Perp market account: ", perp_market_account)
print("Oracle: ", oracle)
oracle_id = self.perp_oracle_strings_map.get(market_index)

if not perp_market_account or not oracle:
return None

if str(perp_market_account.data.amm.oracle) != str(oracle):
asyncio.create_task(self._set_perp_oracle_map())

return self.get_oracle_price_data_and_slot(oracle)
return self.get_oracle_price_data_and_slot(oracle_id)

def get_oracle_price_data_and_slot_for_spot_market(
self, market_index: int
) -> Union[DataAndSlot[OraclePriceData], None]:
spot_market_account = self.get_spot_market_and_slot(market_index)
oracle = self.spot_oracle_map.get(market_index)
oracle_id = self.spot_oracle_strings_map.get(market_index)

if not spot_market_account or not oracle:
return None

if str(spot_market_account.data.oracle) != str(oracle):
asyncio.create_task(self._set_spot_oracle_map())

return self.get_oracle_price_data_and_slot(oracle)
return self.get_oracle_price_data_and_slot(oracle_id)
13 changes: 9 additions & 4 deletions src/driftpy/accounts/ws/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def __init__(
self.spot_market_map = None
self.perp_market_map = None
self.spot_market_oracle_map: dict[int, Pubkey] = {}
self.spot_market_oracle_strings_map: dict[int, str] = {}
self.perp_market_oracle_map: dict[int, Pubkey] = {}
self.perp_market_oracle_strings_map: dict[int, str] = {}

async def subscribe(self):
if self.is_subscribed():
Expand Down Expand Up @@ -178,7 +180,7 @@ async def subscribe_to_oracle(self, full_oracle_wrapper: FullOracleWrapper):

async def subscribe_to_oracle_info(self, oracle_info: OracleInfo):
oracle_id = get_oracle_id(oracle_info.pubkey, oracle_info.source)
if oracle_id == Pubkey.default():
if oracle_info.pubkey == Pubkey.default():
return

if oracle_id in self.oracle_subscribers:
Expand Down Expand Up @@ -221,7 +223,7 @@ async def add_oracle(self, oracle_info: OracleInfo):
if oracle_id in self.oracle_subscribers:
return True

if oracle_id == Pubkey.default():
if oracle_info.pubkey == Pubkey.default():
return True

return await self.subscribe_to_oracle_info(oracle_info)
Expand Down Expand Up @@ -299,6 +301,7 @@ async def _set_perp_oracle_map(self):
OracleInfo(oracle, perp_market_account.amm.oracle_source)
)
self.perp_market_oracle_map[market_index] = oracle
self.perp_market_oracle_strings_map[market_index] = oracle_id

async def _set_spot_oracle_map(self):
spot_markets = self.get_spot_market_accounts_and_slots()
Expand All @@ -315,33 +318,35 @@ async def _set_spot_oracle_map(self):
OracleInfo(oracle, spot_market_account.oracle_source)
)
self.spot_market_oracle_map[market_index] = oracle
self.spot_market_oracle_strings_map[market_index] = oracle_id

def get_oracle_price_data_and_slot_for_perp_market(
self, market_index: int
) -> Union[DataAndSlot[OraclePriceData], None]:
perp_market_account = self.get_perp_market_and_slot(market_index)
oracle = self.perp_market_oracle_map.get(market_index)
oracle_id = self.perp_market_oracle_strings_map.get(market_index)

if not perp_market_account or not oracle:
return None

if perp_market_account.data.amm.oracle != oracle:
asyncio.create_task(self._set_perp_oracle_map())

oracle_id = get_oracle_id(oracle, perp_market_account.data.amm.oracle_source)
return self.get_oracle_price_data_and_slot(oracle_id)

def get_oracle_price_data_and_slot_for_spot_market(
self, market_index: int
) -> Union[DataAndSlot[OraclePriceData], None]:
spot_market_account = self.get_spot_market_and_slot(market_index)
oracle = self.spot_market_oracle_map.get(market_index)
oracle_id = self.spot_market_oracle_strings_map.get(market_index)

if not spot_market_account or not oracle:
return None

if spot_market_account.data.oracle != oracle:
asyncio.create_task(self._set_spot_oracle_map())

oracle_id = get_oracle_id(oracle, spot_market_account.data.oracle_source)
return self.get_oracle_price_data_and_slot(oracle_id)

54 changes: 48 additions & 6 deletions src/driftpy/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from driftpy.drift_client import (
DriftClient,
)
from driftpy.constants.numeric_constants import PEG_PRECISION
from driftpy.types import OracleGuardRails, OracleSource, PrelaunchOracleParams
from driftpy.constants.numeric_constants import BASE_PRECISION, PEG_PRECISION, PRICE_PRECISION
from driftpy.types import AssetTier, ContractTier, OracleGuardRails, OracleSource, PrelaunchOracleParams
from driftpy.addresses import *
from driftpy.accounts import get_state_account
from driftpy.constants.numeric_constants import (
Expand Down Expand Up @@ -61,10 +61,24 @@ async def initialize_perp_market(
periodicity: int,
peg_multiplier: int = PEG_PRECISION,
oracle_source: OracleSource = OracleSource.Pyth(),
contract_tier: ContractTier = ContractTier.Speculative(),
margin_ratio_initial: int = 2000,
margin_ratio_maintenance: int = 500,
liquidation_fee: int = 0,
liquidator_fee: int = 0,
if_liquidator_fee: int = 10000,
imf_factor: int = 0,
active_status: bool = True,
base_spread: int = 0,
max_spread: int = 142500,
max_open_interest: int = 0,
max_revenue_withdraw_per_period: int = 0,
quote_max_insurance: int = 0,
order_step_size: int = BASE_PRECISION // 10000,
order_tick_size: int = PRICE_PRECISION // 100000,
min_order_size: int = BASE_PRECISION // 10000,
concentration_coef_scale: int = 1,
curve_update_intensity: int = 0,
amm_jit_intensity: int = 0,
name: list = [0] * 32,
) -> Signature:
state_public_key = get_state_public_key(self.program.program_id)
Expand All @@ -81,10 +95,24 @@ async def initialize_perp_market(
periodicity,
peg_multiplier,
oracle_source,
contract_tier,
margin_ratio_initial,
margin_ratio_maintenance,
liquidation_fee,
liquidator_fee,
if_liquidator_fee,
imf_factor,
active_status,
base_spread,
max_spread,
max_open_interest,
max_revenue_withdraw_per_period,
quote_max_insurance,
order_step_size,
order_tick_size,
min_order_size,
concentration_coef_scale,
curve_update_intensity,
amm_jit_intensity,
name,
ctx=Context(
accounts={
Expand All @@ -111,7 +139,14 @@ async def initialize_spot_market(
initial_liability_weight: int = SPOT_WEIGHT_PRECISION,
maintenance_liability_weight: int = SPOT_WEIGHT_PRECISION,
imf_factor: int = 0,
liquidation_fee: int = 0,
liquidator_fee: int = 0,
if_liquidation_fee: int = 0,
scale_initial_asset_weight_start: int = 0,
withdraw_guard_threshold: int = 0,
order_tick_size: int = 1,
order_step_size: int = 1,
if_total_factor: int = 0,
asset_tier: AssetTier = AssetTier.COLLATERAL(),
active_status: bool = True,
name: list = [0] * 32,
):
Expand All @@ -137,8 +172,15 @@ async def initialize_spot_market(
initial_liability_weight,
maintenance_liability_weight,
imf_factor,
liquidation_fee,
liquidator_fee,
if_liquidation_fee,
active_status,
asset_tier,
scale_initial_asset_weight_start,
withdraw_guard_threshold,
order_tick_size,
order_step_size,
if_total_factor,
name,
ctx=Context(
accounts={
Expand Down
4 changes: 2 additions & 2 deletions src/driftpy/setup/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ async def mock_oracle(


async def initialize_sol_spot_market(
admin: Admin, sol_oracle: Pubkey, sol_mint: Pubkey = NATIVE_MINT
admin: Admin, sol_oracle: Pubkey, sol_mint: Pubkey = NATIVE_MINT, oracle_source: OracleSource = OracleSource.Pyth()
):
optimal_utilization = SPOT_RATE_PRECISION // 2
optimal_rate = SPOT_RATE_PRECISION * 20
Expand All @@ -352,7 +352,7 @@ async def initialize_sol_spot_market(
optimal_rate,
max_rate,
sol_oracle,
OracleSource.Pyth(),
oracle_source,
initial_asset_weight,
maintenance_asset_weight,
initial_liability_weight,
Expand Down
Loading
Loading