Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add grpc #215

Merged
merged 15 commits into from
Jan 2, 2025
90 changes: 90 additions & 0 deletions examples/grpc_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import asyncio
import os

from anchorpy.provider import Provider, Wallet
from dotenv import load_dotenv
from solana.rpc.async_api import AsyncClient
from solana.rpc.commitment import Commitment
from solders.keypair import Keypair

from driftpy.drift_client import AccountSubscriptionConfig, DriftClient
from driftpy.types import GrpcConfig

load_dotenv()

RED = "\033[91m"
GREEN = "\033[92m"
RESET = "\033[0m"

CLEAR_SCREEN = "\033c"


async def watch_drift_markets():
rpc_fqdn = os.environ.get("RPC_FQDN")
x_token = os.environ.get("X_TOKEN")
private_key = os.environ.get("PRIVATE_KEY")
rpc_url = os.environ.get("RPC_TRITON")

if not (rpc_fqdn and x_token and private_key and rpc_url):
raise ValueError("RPC_FQDN, X_TOKEN, PRIVATE_KEY, and RPC_TRITON must be set")

wallet = Wallet(Keypair.from_base58_string(private_key))
connection = AsyncClient(rpc_url)
provider = Provider(connection, wallet)

drift_client = DriftClient(
provider.connection,
provider.wallet,
"mainnet",
account_subscription=AccountSubscriptionConfig(
"grpc",
grpc_config=GrpcConfig(
endpoint=rpc_fqdn,
token=x_token,
commitment=Commitment("confirmed"),
),
),
)

await drift_client.subscribe()
print("Subscribed via gRPC. Listening for market updates...")

previous_prices = {}

while True:
print(CLEAR_SCREEN, end="")

perp_markets = drift_client.get_perp_market_accounts()

if not perp_markets:
print(f"{RED}No perp markets found (yet){RESET}")
else:
print("Drift Perp Markets (gRPC subscription)\n")
perp_markets.sort(key=lambda x: x.market_index)
for market in perp_markets[:20]:
market_index = market.market_index
last_price = market.amm.historical_oracle_data.last_oracle_price / 1e6

if market_index in previous_prices:
old_price = previous_prices[market_index]
if last_price > old_price:
color = GREEN
elif last_price < old_price:
color = RED
else:
color = RESET
else:
color = RESET

print(
f"Market Index: {market_index} | "
f"Price: {color}${last_price:.4f}{RESET}"
)

previous_prices[market_index] = last_price

await asyncio.sleep(1)


if __name__ == "__main__":
asyncio.run(watch_drift_markets())
781 changes: 408 additions & 373 deletions poetry.lock

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ mypy = "^1.7.0"
deprecated = "^1.2.14"
events = "^0.5"
numpy = "^1.26.2"
jito-searcher-client = "0.1.4"
# jito-searcher-client = "0.1.5"
grpcio = "1.68.1"
protobuf = "5.29.2"

[tool.poetry.dev-dependencies]
pytest = "^7.2.0"
Expand All @@ -97,6 +99,9 @@ mkdocs-material = "^8.1.8"
bump2version = "^1.0.1"
autopep8 = "^2.0.4"
mypy = "^1.7.0"
python-dotenv = "^1.0.0"
ruff = "^0.8.4"


[tool.poetry.group.dev.dependencies]
pytest = "^7.4.4"
Expand All @@ -111,7 +116,7 @@ build-backend = "poetry.core.masonry.api"
asyncio_mode = "strict"

[tool.ruff]
exclude = [".git", "__pycache__", "docs/source/conf.py", "old", "build", "dist"]
exclude = [".git", "__pycache__", "docs/source/conf.py", "old", "build", "dist", "**/geyser_codegen/**"]

[tool.ruff.pycodestyle]
max-line-length = 88
Expand Down
31 changes: 30 additions & 1 deletion src/driftpy/account_subscription_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
DemoDriftClientAccountSubscriber,
DemoUserAccountSubscriber,
)
from driftpy.accounts.grpc.account_subscriber import GrpcConfig
from driftpy.accounts.grpc.drift_client import GrpcDriftClientAccountSubscriber
from driftpy.accounts.grpc.user import GrpcUserAccountSubscriber
from driftpy.accounts.polling import (
PollingDriftClientAccountSubscriber,
PollingUserAccountSubscriber,
Expand All @@ -32,13 +35,17 @@ def default():

def __init__(
self,
account_subscription_type: Literal["polling", "websocket", "cached", "demo"],
account_subscription_type: Literal[
"polling", "websocket", "cached", "demo", "grpc"
],
bulk_account_loader: Optional[BulkAccountLoader] = None,
commitment: Commitment = Commitment("confirmed"),
grpc_config: Optional[GrpcConfig] = None,
):
self.type = account_subscription_type
self.commitment = commitment
self.bulk_account_loader = None
self.grpc_config = grpc_config

if self.type != "polling":
return
Expand Down Expand Up @@ -117,6 +124,18 @@ def get_drift_client_subscriber(
oracle_infos,
self.commitment,
)
case "grpc":
if self.grpc_config is None:
raise ValueError("A grpc config is required for grpc subscription")
return GrpcDriftClientAccountSubscriber(
program,
self.grpc_config,
perp_market_indexes,
spot_market_indexes,
cast(list[FullOracleWrapper], oracle_infos),
should_find_all_markets_and_oracles,
self.commitment,
)

def get_user_client_subscriber(self, program: Program, user_pubkey: Pubkey):
match self.type:
Expand All @@ -138,3 +157,13 @@ def get_user_client_subscriber(self, program: Program, user_pubkey: Pubkey):
)
case "demo":
return DemoUserAccountSubscriber(user_pubkey, program, self.commitment)
case "grpc":
if self.grpc_config is None:
raise ValueError("A grpc config is required for grpc subscription")
return GrpcUserAccountSubscriber(
grpc_config=self.grpc_config,
account_name="user",
account_public_key=user_pubkey,
program=program,
commitment=self.commitment,
)
155 changes: 155 additions & 0 deletions src/driftpy/accounts/grpc/account_subscriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import asyncio
import time
from typing import Callable, Optional, TypeVar

import grpc.aio
from anchorpy.program.core import Program
from solana.rpc.commitment import Commitment
from solders.pubkey import Pubkey

from driftpy.accounts.grpc.geyser_codegen import geyser_pb2, geyser_pb2_grpc
from driftpy.accounts.types import DataAndSlot
from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
from driftpy.types import GrpcConfig

T = TypeVar("T")


class TritonAuthMetadataPlugin(grpc.AuthMetadataPlugin):
def __init__(self, x_token: str):
self.x_token = x_token

def __call__(
self,
context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
):
metadata = (("x-token", self.x_token),)
callback(metadata, None)


class GrpcAccountSubscriber(WebsocketAccountSubscriber[T]):
def __init__(
self,
grpc_config: GrpcConfig,
account_name: str,
program: Program,
account_public_key: Pubkey,
commitment: Commitment = Commitment("confirmed"),
decode: Optional[Callable[[bytes], T]] = None,
initial_data: Optional[DataAndSlot[T]] = None,
):
super().__init__(account_public_key, program, commitment, decode, initial_data)
self.client = self._create_grpc_client(grpc_config)
self.stream = None
self.listener_id = None
self.account_name = account_name
self.decode = (
decode if decode is not None else self.program.coder.accounts.decode
)

def _create_grpc_client(self, config: GrpcConfig) -> geyser_pb2_grpc.GeyserStub:
auth = TritonAuthMetadataPlugin(config.token)
ssl_creds = grpc.ssl_channel_credentials()
call_creds = grpc.metadata_call_credentials(auth)
combined_creds = grpc.composite_channel_credentials(ssl_creds, call_creds)

channel = grpc.aio.secure_channel(config.endpoint, credentials=combined_creds)
return geyser_pb2_grpc.GeyserStub(channel)

async def subscribe(self) -> Optional[asyncio.Task[None]]:
if self.listener_id is not None:
return

self.task = asyncio.create_task(self._subscribe_grpc())
return self.task

async def _subscribe_grpc(self):
"""Internal method to handle the gRPC subscription"""
if self.data_and_slot is None:
await self.fetch()

try:
request_iterator = self._create_subscribe_request()
self.stream = self.client.Subscribe(request_iterator)
await self.stream.wait_for_connection()

self.listener_id = 1

async for update in self.stream:
try:
if update.HasField("ping") or update.HasField("pong"):
continue

if not update.HasField("account"):
print(f"No account for {self.account_name}")
continue

slot = int(update.account.slot)
account_info = {
"owner": Pubkey.from_bytes(update.account.account.owner),
"lamports": int(update.account.account.lamports),
"data": bytes(update.account.account.data),
"executable": update.account.account.executable,
"rent_epoch": int(update.account.account.rent_epoch),
}

if not account_info["data"]:
print(f"No data for {self.account_name}")
continue

decoded_data = (
self.decode(account_info["data"])
if self.decode
else account_info
)
self.update_data(DataAndSlot(slot, decoded_data))

except Exception as e:
print(f"Error processing account data for {self.account_name}: {e}")
break

except Exception as e:
print(f"Error in gRPC subscription for {self.account_name}: {e}")
if self.stream:
self.stream.cancel()
self.listener_id = None
raise e

async def _create_subscribe_request(self):
request = geyser_pb2.SubscribeRequest()
account_filter = geyser_pb2.SubscribeRequestFilterAccounts()
account_filter.account.append(str(self.pubkey))
account_filter.nonempty_txn_signature = True
request.accounts["account_monitor"].CopyFrom(account_filter)

request.commitment = geyser_pb2.CommitmentLevel.CONFIRMED
if self.commitment == Commitment("finalized"):
request.commitment = geyser_pb2.CommitmentLevel.FINALIZED
if self.commitment == Commitment("processed"):
request.commitment = geyser_pb2.CommitmentLevel.PROCESSED

yield request

while True:
await asyncio.sleep(30)
ping_request = geyser_pb2.SubscribeRequest()
ping_request.ping.id = int(time.time())
yield ping_request

async def unsubscribe(self) -> None:
if self.listener_id is not None:
try:
if self.stream:
self.stream.cancel()
self.listener_id = None
except Exception as e:
print(f"Error unsubscribing from account {self.account_name}: {e}")
raise e

def update_data(self, new_data: Optional[DataAndSlot[T]]):
if new_data is None:
return

if self.data_and_slot is None or new_data.slot >= self.data_and_slot.slot:
self.data_and_slot = new_data
Loading