Skip to content

Commit

Permalink
111 http perf python fix (#119)
Browse files Browse the repository at this point in the history
* update .gitignore

* removing comments, small fixes to v1

* remove copies from orderbook loop

* dates were unsorted

* fix cancel_order logic to uistv2 to return cancelled order id

* add new ref to normal orders

* add order_id_ref to modify_order result

* remove clone on cancel order

* fix python grid strat

* fixed bug with last tick

* fix memory usage in http tick

* correct logging level in Python

* fmt
  • Loading branch information
calumrussell authored Nov 9, 2024
1 parent a706a4d commit a029b80
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 65 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
*.csv
/.vscode
Cargo.lock
data/venv/
rotala-python/venv/
test_data/
*__pycache__*
36 changes: 21 additions & 15 deletions rotala-http/src/http/uist_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,31 @@ impl AppState {

if let Some(quotes) = dataset.get_quotes(&curr_date) {
let mut res = backtest.exchange.tick(quotes, curr_date);
executed_orders.append(&mut res.0);
inserted_orders.append(&mut res.1);

executed_orders = std::mem::take(&mut res.0);
inserted_orders = std::mem::take(&mut res.1);
}

let new_pos = backtest.pos + 1;
let new_date = *dataset.get_date(new_pos).unwrap();

if dataset.has_next(new_pos) {
has_next = true;
backtest.date = new_date;
if let Some(new_date) = (*dataset).get_date(new_pos) {
if dataset.has_next(new_pos) {
has_next = true;
backtest.date = *new_date;
}
backtest.pos = new_pos;
let bbo = dataset.get_bbo(new_date).unwrap();
//Have to clone here because we can't mutate immutable dataset
let depth = dataset.get_quotes(new_date).unwrap().clone();
return Some((has_next, executed_orders, inserted_orders, bbo, depth));
} else {
return Some((
false,
Vec::new(),
Vec::new(),
HashMap::new(),
HashMap::new(),
));
}
backtest.pos = new_pos;

let bbo = dataset.get_bbo(new_date).unwrap();
//TODO: shouldn't clone here
let depth = dataset.get_quotes(&new_date).unwrap().clone();

return Some((has_next, executed_orders, inserted_orders, bbo, depth));
}
}
None
Expand Down Expand Up @@ -288,7 +295,6 @@ pub mod server {
mut insert_order: web::Json<InsertOrderRequest>,
) -> Result<web::Json<()>, UistV2Error> {
let (backtest_id,) = path.into_inner();
//TODO: shouldn't need clone here
let take_orders = std::mem::take(&mut insert_order.orders);
if let Some(()) = app.insert_orders(take_orders, backtest_id) {
Ok(web::Json(()))
Expand Down
12 changes: 7 additions & 5 deletions rotala-python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create_orders(bid_grid, ask_grid):


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.CRITICAL)

builder = BrokerBuilder()
builder.init_dataset_name("Test")
Expand All @@ -61,19 +61,23 @@ def create_orders(bid_grid, ask_grid):

last_mid = -1
while True:
brkr.tick()

depth = brkr.latest_depth
bid_grid, ask_grid = create_grid(depth)

best_bid, best_ask, mid_price = get_best_and_mid(depth)
if last_mid == -1:
last_mid = mid_price

mid_change = round(abs(last_mid - mid_price), 2)
last_mid = mid_price

risk = risk_management(brkr.unexecuted_orders, brkr.get_current_value())
if len(brkr.unexecuted_orders) == 0:
[brkr.insert_order(order) for order in create_orders(bid_grid, ask_grid)]
else:
mid_change = round(abs(last_mid - mid_price), 2)
if mid_change > 0.4:
if mid_change > 0.1:
# In practice, we want to look for overlapping levels so we don't need
# to clear whole book
for order_id in brkr.unexecuted_orders:
Expand All @@ -87,5 +91,3 @@ def create_orders(bid_grid, ask_grid):
brkr.insert_order(order)
for order in create_orders(bid_grid, ask_grid)
]

brkr.tick()
15 changes: 10 additions & 5 deletions rotala-python/src/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,15 @@ def __init__(
date: int,
typ: OrderResultType,
order_id: int,
order_id_ref: int | None,
):
self.symbol = symbol
self.value = value
self.quantity = quantity
self.date = date
self.typ = typ
self.order_id = order_id
self.order_id_ref = order_id_ref

def __str__(self):
return (
Expand All @@ -145,6 +147,7 @@ def from_dict(from_dict: dict):
from_dict["date"],
trade_type,
from_dict["order_id"],
from_dict["order_id_ref"],
)

@staticmethod
Expand All @@ -157,6 +160,7 @@ def from_json(json_str: str):
to_dict["date"],
to_dict["typ"],
to_dict["order_id"],
to_dict["order_id_ref"],
)


Expand Down Expand Up @@ -188,13 +192,13 @@ def _update_holdings(self, position: str, chg: float):

curr_position = self.holdings[position]
new_position = curr_position + chg
logger.info(
logger.debug(
f"{self.backtest_id}-{self.ts} POSITION CHG: {position} {curr_position} -> {new_position}"
)
self.holdings[position] = new_position

def _process_order_result(self, result: OrderResult):
logger.info(f"{self.backtest_id}-{self.ts} EXECUTED: {result}")
logger.debug(f"{self.backtest_id}-{self.ts} EXECUTED: {result}")

if result.typ == OrderResultType.Buy or result.typ == OrderResultType.Sell:
before_trade = self.cash
Expand All @@ -204,7 +208,7 @@ def _process_order_result(self, result: OrderResult):
else self.cash + result.value
)

logger.info(
logger.debug(
f"{self.backtest_id}-{self.ts} CASH: {before_trade} -> {after_trade}"
)
self.cash = after_trade
Expand All @@ -228,6 +232,7 @@ def _process_order_result(self, result: OrderResult):
else:
if result.typ == OrderResultType.Cancel:
del self.unexecuted_orders[result.order_id]
del self.unexecuted_orders[result.order_id_ref]
else:
logger.critical("Unsupported order modification type")
exit(1)
Expand Down Expand Up @@ -261,7 +266,7 @@ def tick(self):
logger.info(f"{self.backtest_id}-{self.ts} TICK")

# Flush pending orders
logger.info(
logger.debug(
f"{self.backtest_id}-{self.ts} INSERTING {len(self.pending_orders)} ORDER"
)
self.http.insert_orders(self.pending_orders)
Expand Down Expand Up @@ -289,5 +294,5 @@ def tick(self):
self.ts = list(self.latest_quotes.values())[0]["date"]

curr_value = self.get_current_value()
logger.info(f"{self.backtest_id}-{self.ts} TOTAL VALUE: {curr_value}")
logger.debug(f"{self.backtest_id}-{self.ts} TOTAL VALUE: {curr_value}")
self.portfolio_values.append(curr_value)
23 changes: 18 additions & 5 deletions rotala-python/src/http.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
import requests
from urllib3.util import Retry
from requests import Session
from requests.adapters import HTTPAdapter


class HttpClient:
def __init__(self, base_url):
self.base_url = base_url
self.backtest_id = None

s = Session()
retries = Retry(
total=3,
backoff_factor=0.1,
status_forcelist=[502, 503, 504],
allowed_methods={"POST"},
)
s.mount("https://", HTTPAdapter(max_retries=retries))
self.s = s
return

def init(self, dataset_name):
r = requests.get(f"{self.base_url}/init/{dataset_name}")
r = self.s.get(f"{self.base_url}/init/{dataset_name}")
json_response = r.json()
self.backtest_id = int(json_response["backtest_id"])
return json_response
Expand All @@ -17,7 +30,7 @@ def tick(self):
if self.backtest_id is None:
raise ValueError("Called before init")

r = requests.get(f"{self.base_url}/backtest/{self.backtest_id}/tick")
r = self.s.get(f"{self.base_url}/backtest/{self.backtest_id}/tick")
return r.json()

def insert_orders(self, orders):
Expand All @@ -26,7 +39,7 @@ def insert_orders(self, orders):

serialized_orders_str = ",".join([o.serialize() for o in orders])
val = f'{{"orders": [{serialized_orders_str}]}}'
r = requests.post(
r = self.s.post(
f"{self.base_url}/backtest/{self.backtest_id}/insert_orders",
data=val,
headers={"Content-type": "application/json"},
Expand All @@ -37,12 +50,12 @@ def info(self):
if self.backtest_id is None:
raise ValueError("Called before init")

r = requests.get(f"{self.base_url}/backtest/{self.backtest_id}/info")
r = self.s.get(f"{self.base_url}/backtest/{self.backtest_id}/info")
return r.json()

def now(self, backtest_id):
if self.backtest_id is None:
raise ValueError("Called before init")

r = requests.get(f"{self.base_url}/backtest/{self.backtest_id}/now")
r = self.s.get(f"{self.base_url}/backtest/{self.backtest_id}/now")
return r.json()
21 changes: 12 additions & 9 deletions rotala/src/exchange/uist_v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ impl UistV1 {
}
}

pub fn executed_trade_count(&self) -> usize {
self.trade_log.len()
}

fn sort_order_buffer(&mut self) {
self.order_buffer.sort_by(|a, _b| match a.get_order_type() {
OrderType::LimitSell | OrderType::StopSell | OrderType::MarketSell => {
Expand Down Expand Up @@ -372,8 +376,7 @@ mod tests {
exchange.tick(source.get_quotes_unchecked(&100));
exchange.tick(source.get_quotes_unchecked(&101));

//TODO: no abstraction!
assert_eq!(exchange.trade_log.len(), 1);
assert_eq!(exchange.executed_trade_count(), 1);
}

#[test]
Expand All @@ -387,7 +390,7 @@ mod tests {

exchange.tick(source.get_quotes_unchecked(&100));
exchange.tick(source.get_quotes_unchecked(&101));
assert_eq!(exchange.trade_log.len(), 4);
assert_eq!(exchange.executed_trade_count(), 4);
}

#[test]
Expand All @@ -402,7 +405,7 @@ mod tests {
exchange.tick(source.get_quotes_unchecked(&101));
exchange.tick(source.get_quotes_unchecked(&102));

assert_eq!(exchange.trade_log.len(), 4);
assert_eq!(exchange.executed_trade_count(), 4);
}

#[test]
Expand All @@ -414,7 +417,7 @@ mod tests {
exchange.tick(source.get_quotes_unchecked(&100));
exchange.tick(source.get_quotes_unchecked(&101));

assert_eq!(exchange.trade_log.len(), 1);
assert_eq!(exchange.executed_trade_count(), 1);
let trade = exchange.trade_log.remove(0);
//Trade executes at 101 so trade price should be 103
assert_eq!(trade.value / trade.quantity, 103.00);
Expand All @@ -430,7 +433,7 @@ mod tests {
exchange.tick(source.get_quotes_unchecked(&100));
exchange.tick(source.get_quotes_unchecked(&101));

assert_eq!(exchange.trade_log.len(), 1);
assert_eq!(exchange.executed_trade_count(), 1);
let trade = exchange.trade_log.remove(0);
//Trade executes at 101 so trade price should be 103
assert_eq!(trade.value / trade.quantity, 102.00);
Expand All @@ -444,7 +447,7 @@ mod tests {
exchange.insert_order(Order::market_buy("XYZ", 100.0));
exchange.tick(source.get_quotes_unchecked(&100));

assert_eq!(exchange.trade_log.len(), 0);
assert_eq!(exchange.executed_trade_count(), 0);
}

#[test]
Expand All @@ -468,11 +471,11 @@ mod tests {
exchange.insert_order(Order::market_buy("ABC", 100.0));
exchange.tick(source.get_quotes_unchecked(&100));
//Orderbook should have one order and trade log has no executed trades
assert_eq!(exchange.trade_log.len(), 0);
assert_eq!(exchange.executed_trade_count(), 0);

exchange.tick(source.get_quotes_unchecked(&102));
//Order should execute now
assert_eq!(exchange.trade_log.len(), 1);
assert_eq!(exchange.executed_trade_count(), 1);
}

#[test]
Expand Down
Loading

0 comments on commit a029b80

Please sign in to comment.