Skip to content

Commit

Permalink
updated
Browse files Browse the repository at this point in the history
  • Loading branch information
abbass2 committed Oct 27, 2023
1 parent 1046f5f commit ba7f7b3
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 49 deletions.
8 changes: 3 additions & 5 deletions pyqstrat/notebooks/multiple_contracts.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@
" on_ret: float # overnight return\n",
" timestamps: np.ndarray\n",
" prices: np.ndarray\n",
" \n",
" \n",
"\n",
" \n",
"def create_overnight_returns(contracts: list[str]) -> dict[np.datetime64, OvernightReturn]:\n",
Expand All @@ -108,9 +106,6 @@
" date_prices = prices[prices.date == date]\n",
" on_rets[date] = OvernightReturn(name, on_ret[i], date_prices.timestamp.values.astype('M8[m]'), date_prices.c.values)\n",
" return on_rets\n",
" \n",
" \n",
"\n",
"\n",
"\n",
"def create_price_dataframe(on_rets: dict[np.datetime64, OvernightReturn]) -> pd.DataFrame:\n",
Expand Down Expand Up @@ -152,6 +147,9 @@
"\n",
"@dataclass\n",
"class ContractFilter:\n",
" '''\n",
" For each day we want to trade only one symbol. So don't allow trade entry for all others\n",
" '''\n",
" def __init__(self, entry_contracts: dict[np.datetime64, list[str]]) -> None:\n",
" self.entry_contracts = entry_contracts\n",
" \n",
Expand Down
20 changes: 18 additions & 2 deletions pyqstrat/pq_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,11 @@ def get_contracts(self) -> list[Contract]:

@staticmethod
def clear_cache() -> None:
ContractGroup.contracts = {}
ContractGroup.contracts = {}
ContractGroup._instances = {}

def clear(self) -> None:
'''Remove all contracts'''
self.contracts.clear()

def __repr__(self) -> str:
return self.name
Expand Down Expand Up @@ -132,6 +135,19 @@ def get(name) -> Contract | None:
'''
return Contract._instances.get(name)

@staticmethod
def get_or_create(symbol: str,
contract_group: ContractGroup | None = None,
expiry: np.datetime64 | None = None,
multiplier: float = 1.,
components: list[tuple[Contract, float]] | None = None,
properties: SimpleNamespace | None = None) -> Contract:
if symbol in Contract._instances:
contract = Contract._instances.get(symbol)
else:
contract = Contract.create(symbol, contract_group, expiry, multiplier, components, properties)
return contract # type: ignore

@staticmethod
def clear_cache() -> None:
'''
Expand Down
117 changes: 75 additions & 42 deletions pyqstrat/strategy_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_contract_price_from_array_dict(price_dict: dict[str, tuple[np.ndarray, n
idx = np_indexof_sorted(_timestamps, timestamp)
if idx == -1: return math.nan
return tup[1][idx] # type: ignore


@dataclass
class PriceFuncArrayDict:
Expand Down Expand Up @@ -158,7 +158,7 @@ class SimpleMarketSimulator:
>>> timestamp = np.datetime64('2023-01-03 14:35')
>>> price_func = PriceFuncDict({put_symbol: {timestamp: 4.8}, call_symbol: {timestamp: 3.5}})
>>> order = MarketOrder(contract=basket, timestamp=timestamp, qty=10, reason_code='TEST')
>>> sim = SimpleMarketSimulator(price_func=price_func, slippage_per_trade=0)
>>> sim = SimpleMarketSimulator(price_func=price_func, slippage=0)
>>> out = sim([order], 0, np.array([timestamp]), {}, {}, SimpleNamespace())
>>> assert(len(out) == 1)
>>> assert(math.isclose(out[0].price, -1.3))
Expand All @@ -167,21 +167,24 @@ class SimpleMarketSimulator:
price_func: PriceFunctionType
slippage: float
commission: float
post_trade_func: Callable[[Trade, StrategyContextType], None] | None

def __init__(self,
price_func: PriceFunctionType,
slippage_per_trade: float = 0.,
commission_per_trade: float = 0) -> None:
slippage: float = 0.,
commission: float = 0,
post_trade_func: Callable[[Trade, StrategyContextType], None] | None = None) -> None:
'''
Args:
price_func: A function that we use to get the price to execute at
slippage_per_trade: Slippage in local currency. Meant to simulate the difference
between bid/ask mid and execution price
commission_per_trade: Fee paid to broker per trade
slippage: Slippage per dollar transacted.
Meant to simulate the difference between bid/ask mid and execution price
commission: Fee paid to broker per trade
'''
self.price_func: PriceFunctionType = price_func
self.slippage: float = slippage_per_trade
self.commission: float = commission_per_trade
self.slippage: float = slippage
self.commission: float = commission
self.post_trade_func = post_trade_func

def __call__(self,
orders: Sequence[Order],
Expand All @@ -190,7 +193,7 @@ def __call__(self,
indicators: dict[ContractGroup, SimpleNamespace],
signals: dict[ContractGroup, SimpleNamespace],
strategy_context: SimpleNamespace) -> list[Trade]:
'''TODO: code for limit orders and stop orders'''
'''TODO: code for stop orders'''
trades = []
timestamp = timestamps[i]
# _logger.info(f'got: {orders}')
Expand All @@ -206,7 +209,7 @@ def __call__(self,
if np.isnan(raw_price):
break
if np.isnan(raw_price): continue
slippage = self.slippage * order.qty
slippage = self.slippage * order.qty * raw_price
if order.qty < 0: slippage = -slippage
price = raw_price + slippage
if isinstance(order, LimitOrder) and np.isfinite(order.limit_price):
Expand All @@ -219,6 +222,8 @@ def __call__(self,
_logger.info(f'TRADE: {timestamp.astype("M8[m]")} {trade}')
order.fill()
trades.append(trade)
if self.post_trade_func is not None:
self.post_trade_func(trade, strategy_context)
return trades


Expand Down Expand Up @@ -509,15 +514,15 @@ def __call__(self,


ContractFilterType = Callable[
[Contract,
[ContractGroup,
int,
np.ndarray,
SimpleNamespace,
np.ndarray,
Account,
Sequence[Order],
StrategyContextType],
list[str] | None]
list[str]]


@dataclass
Expand All @@ -533,36 +538,60 @@ class FiniteRiskEntryRule:
Used to calculate order qty so that if we get stopped out, we don't lose
more than this amount. Of course if price gaps up or down rather than moving smoothly,
we may lose more.
stop_price_ind: An indicator containing the stop price so we exit the order when this is breached
stop_price_func: An indicator containing the stop price so we exit the order when this is breached
contract_filter: A function that takes similar arguments as a rule (with ContractGroup) replaced by
Contract but returns a list of contract names for each positive signal timestamp. For example,
for a strategy that trades 5000 stocks, you may want to construct a single signal and apply it
to different contracts at different times, rather than create 5000 signals that will call your rule
5000 times every time the signal is true.
>>> timestamps = np.arange(np.datetime64('2023-01-01'), np.datetime64('2023-01-05'))
>>> sig_values = np.full(len(timestamps), False)
>>> aapl_prices = np.array([100.1, 100.2, 100.3, 100.4])
>>> ibm_prices = np.array([200.1, 200.2, 200.3, 200.4])
>>> stops = np.array([50, 75, 85, 100.35])
>>> price_dict = {'AAPL': (timestamps, aapl_prices), 'IBM': (timestamps, ibm_prices)}
>>> stop_dict = {'AAPL': (timestamps, stops), 'IBM': (timestamps, stops)}
>>> price_func = PriceFuncArrayDict(price_dict)
>>> fr = FiniteRiskEntryRule('TEST_ENTRY', price_func, False)
>>> default_cg = ContractGroup.get('DEFAULT')
>>> default_cg.clear_cache()
>>> default_cg.add_contract(Contract.get_or_create('AAPL'))
>>> default_cg.add_contract(Contract.get_or_create('IBM'))
>>> account = SimpleNamespace()
>>> account.equity = lambda x: 1e6
>>> orders = fr(default_cg, 1, timestamps, SimpleNamespace(), sig_values, account, [], SimpleNamespace())
>>> assert len(orders) == 2 and orders[0].qty == 998 and orders[1].qty == 499
>>> stop_price_func = PriceFuncArrayDict(stop_dict)
>>> fr = FiniteRiskEntryRule('TEST_ENTRY', price_func, long=True, stop_price_func=stop_price_func, min_price_diff=0.1)
>>> orders = fr(default_cg, 2, timestamps, SimpleNamespace(), sig_values, account, [], SimpleNamespace())
>>> assert len(orders) == 2 and orders[0].qty == 6535 and orders[1].qty == 867
>>> orders = fr(default_cg, 3, timestamps, SimpleNamespace(), sig_values, account, [], SimpleNamespace())
>>> assert len(orders) == 1 and orders[0].qty == 999
'''
reason_code: str
price_func: PriceFunctionType
long: bool
percent_of_equity: float
stop_price_ind: str | None
min_price_diff: float
single_entry_per_day: bool
contract_filter: ContractFilterType | None

stop_price_func: PriceFunctionType | None

def __init__(self,
reason_code: str,
price_func: PriceFunctionType,
long: bool = True,
percent_of_equity: float = 0.1,
stop_price_ind: str | None = None,
min_price_diff: float = 0,
single_entry_per_day: bool = False,
contract_filter: ContractFilterType | None = None) -> None:
contract_filter: ContractFilterType | None = None,
stop_price_func: PriceFunctionType | None = None) -> None:
self.reason_code = reason_code
self.price_func = price_func
self.long = long
self.percent_of_equity = percent_of_equity
self.stop_price_ind = stop_price_ind
self.stop_price_func = stop_price_func
self.min_price_diff = min_price_diff
self.single_entry_per_day = single_entry_per_day
self.contract_filter = contract_filter
Expand All @@ -576,38 +605,43 @@ def __call__(self,
account: Account,
current_orders: Sequence[Order],
strategy_context: StrategyContextType) -> list[Order]:
# import pdb; pdb.set_trace()
timestamp = timestamps[i]
if self.single_entry_per_day:
date = timestamp.astype('M8[D]')
trades = account.get_trades_for_date(contract_group.name, date)
if len(trades): return []
date = timestamp.astype('M8[D]')

contracts = contract_group.get_contracts()
contracts: list[Contract] = []
if self.contract_filter is not None:
names = self.contract_filter(
contract_group, i, timestamps, indicator_values, signal_values, account, current_orders, strategy_context)
for name in names:
_contract = Contract.get(name)
if _contract is None: continue
contracts.append(_contract)
else:
contracts = contract_group.get_contracts()

orders: list[Order] = []
for contract in contracts:
if self.contract_filter is not None:
relevant_contracts = self.contract_filter(
contract, i, timestamps, indicator_values, signal_values, account, current_orders, strategy_context)
if relevant_contracts is None or contract.symbol not in relevant_contracts: continue

if self.single_entry_per_day:
trades = account.get_trades_for_date(contract_group.name, date)
if len(trades): continue

entry_price_est = self.price_func(contract, timestamps, i, strategy_context) # type: ignore
if math.isnan(entry_price_est): return []

if self.stop_price_ind:
_stop_price_ind = getattr(indicator_values, self.stop_price_ind)
stop_price = _stop_price_ind[i]
else:
stop_price = 0.
if math.isnan(entry_price_est): continue

if self.long and (entry_price_est - stop_price) < self.min_price_diff: return []
if not self.long and (stop_price - entry_price_est) < self.min_price_diff: return []
stop_price = 0.
if self.stop_price_func is not None:
stop_price = self.stop_price_func(contract, timestamps, i, strategy_context) # type: ignore
if abs(entry_price_est - stop_price) < self.min_price_diff:
_logger.info(f'entry price estimate: {entry_price_est} too close to stop price: {stop_price}'
f' tolerance: {self.min_price_diff}')
continue

curr_equity = account.equity(timestamp)
risk_amount = self.percent_of_equity * curr_equity
order_qty = risk_amount / (entry_price_est - stop_price)
order_qty /= len(contracts) # divide up equity equally
order_qty = math.floor(order_qty) if order_qty > 0 else math.ceil(order_qty)
if math.isclose(order_qty, 0.): return []
if math.isclose(order_qty, 0.): continue
order = MarketOrder(contract=contract, # type: ignore
timestamp=timestamp,
qty=order_qty,
Expand Down Expand Up @@ -639,8 +673,7 @@ def __init__(self,
log_orders: bool = True) -> None:
self.reason_code = reason_code
self.price_func = price_func
assert_(math.isnan(limit_increment) or limit_increment >= 0,
f'limit_increment: {limit_increment} cannot be negative')
assert_(math.isnan(limit_increment) or limit_increment >= 0, f'limit_increment: {limit_increment} cannot be negative')
self.limit_increment = limit_increment
self.log_orders = log_orders

Expand Down

0 comments on commit ba7f7b3

Please sign in to comment.