Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
scheuclu committed Sep 18, 2024
1 parent 0d3d671 commit edaba21
Show file tree
Hide file tree
Showing 18 changed files with 691 additions and 342 deletions.
122 changes: 122 additions & 0 deletions cli_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import argparse
import logging
import os
from datetime import timedelta
from typing import Callable, Literal, Optional, Protocol, Union

from mypy_extensions import NamedArg
from pytimeparse.timeparse import timeparse

Main = Callable[
[
NamedArg(Optional[int], "dashboard_server_port"),
NamedArg(bool, "simulation_status_bar"),
NamedArg(bool, "auto_close"),
NamedArg(Optional[timedelta], "run_length"),
],
None,
]

from enum import Enum


class CliLogLevel(Enum):
CRITICAL = 50
FATAL = CRITICAL
ERROR = 40
WARNING = 30
WARN = WARNING
INFO = 20
DEBUG = 10
NOTSET = 0

def __str__(self) -> str:
return logging._levelToName[self.value]

@staticmethod
def of_string(s: str) -> "CliLogLevel":
s = s.upper()
try:
return CliLogLevel[s]
except:
raise Exception(f"{s=} is not a valid log level!")


class DojoLogFilter(logging.Filter):
def __init__(self) -> None:
super().__init__("dojo")


class NotLogFilter(logging.Filter):
def __init__(self, inner: logging.Filter):
self.inner = inner

def filter(self, record: logging.LogRecord) -> bool:
return not (self.inner.filter(record))


def run_main(main_f: Main) -> None:
parser = argparse.ArgumentParser(description="Run a Dojo Simulation")
default_dashboard_server_port = 8786
dashboard_group = parser.add_mutually_exclusive_group()
dashboard_group.add_argument(
"--dashboard-server-port",
type=int,
default=default_dashboard_server_port,
help=f"The port the dashboard should be server on. {default_dashboard_server_port=}",
)
dashboard_group.add_argument("--no-dashboard", action="store_true")
parser.add_argument(
"--log-level",
type=CliLogLevel.of_string,
choices=list(CliLogLevel),
default=CliLogLevel.INFO,
help="log level for dojo",
)
parser.add_argument(
"--global-log-level",
type=CliLogLevel.of_string,
choices=list(CliLogLevel),
default=CliLogLevel.ERROR,
help="log level for all libraries other than dojo",
)
parser.add_argument("--simulation-status-bar", type=bool, default=False)
parser.add_argument("--auto-close", type=bool, default=True)
parser.add_argument(
"--run-length",
type=lambda s: timedelta(seconds=timeparse(s)),
default=None,
help="parsed using pytimeparse (e.g. you can use '1h')",
)
args = parser.parse_args()

main_handler = logging.StreamHandler()
main_handler.setLevel(args.global_log_level.value)
main_handler.addFilter(NotLogFilter(DojoLogFilter()))
dojo_handler = logging.StreamHandler()
dojo_handler.setLevel(args.log_level.value)
dojo_handler.addFilter(DojoLogFilter())
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s",
level=logging.NOTSET,
handlers=[dojo_handler, main_handler],
)

if args.no_dashboard:
dashboard_server_port: Optional[int] = None
else:
dashboard_server_port = args.dashboard_server_port

call_args = {
"dashboard_server_port": dashboard_server_port,
"simulation_status_bar": args.simulation_status_bar,
"auto_close": args.auto_close,
"run_length": args.run_length,
}
call_args = {
# let main functions use default values for run_length - we don't want to override them with the 'None' value
k: v
for k, v in call_args.items()
if not (k == "run_length" and v == None)
}
main_f(**call_args)
53 changes: 17 additions & 36 deletions example_backtest.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
import argparse
import logging
import os
import sys
from datetime import timedelta
from decimal import Decimal
from typing import Union

# Example logging configuration which:
# - by default, it logs only INFO
# - logs DEBUG messages only from "dojo.network"
# For more config options, see: https://docs.python.org/3.12/howto/logging-cookbook.html#customizing-handlers-with-dictconfig
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s",
level=logging.ERROR,
handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
from typing import Optional

import cli_runner
from agents.uniswapV3_pool_wealth import UniswapV3PoolWealthAgent
from dateutil import parser as dateparser
from examples.moving_averages.policy import MovingAveragePolicy
Expand All @@ -24,15 +16,21 @@
from dojo.runners import backtest_run


def main(dashboard_server_port: Union[int, None]) -> None:
def main(
*,
dashboard_server_port: Optional[int],
simulation_status_bar: bool,
auto_close: bool,
run_length: timedelta = timedelta(hours=1),
) -> None:

# SNIPPET 1 START
# pools = ["USDC/WETH-0.3"]
# start_time = dateparser.parse("2021-06-22 00:00:00 UTC")
# end_time = dateparser.parse("2021-06-22 12:0:00 UTC")
pools = ["USDC/WETH-0.05"]
start_time = dateparser.parse("2022-06-21 00:00:00 UTC")
end_time = dateparser.parse("2022-06-21 01:00:00 UTC")
end_time = start_time + run_length

# Agents
agent1 = UniswapV3PoolWealthAgent(
Expand Down Expand Up @@ -71,29 +69,12 @@ def main(dashboard_server_port: Union[int, None]) -> None:
backtest_run(
env,
[mvag_policy, passive_lp_policy],
dashboard_server_port=dashboard_server_port
if dashboard_server_port != -1
else None,
output_dir="./",
auto_close=True,
simulation_status_bar=True,
dashboard_server_port=dashboard_server_port,
auto_close=auto_close,
simulation_status_bar=simulation_status_bar,
)
# SNIPPET 1 END


if __name__ == "__main__":
# Adding a boolean flag
parser = argparse.ArgumentParser(
description="Example script for boolean argument with argparse"
)

# Adding an optional integer argument with a default value
parser.add_argument(
"--dashboard_server_port",
type=int,
default=8786,
help="What port the data should be served on. Can be None, in which case the data is not being served.",
)

args = parser.parse_args()
main(dashboard_server_port=args.dashboard_server_port)
cli_runner.run_main(main)
33 changes: 21 additions & 12 deletions examples/aavev3/run.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# import logging
import logging
import os
import sys
from datetime import timedelta
from decimal import Decimal
from typing import Optional

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

import cli_runner
from dateutil import parser as dateparser
from policy import AAVEv3Policy

Expand All @@ -12,10 +17,6 @@
from dojo.environments.aaveV3 import AAVEv3Observation
from dojo.runners import backtest_run

# logging.basicConfig(
# format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
# )


class ConstantRewardAgent(AAVEv3Agent):
"""An agent that does not have any particular objective."""
Expand All @@ -32,9 +33,15 @@ def reward(self, obs: AAVEv3Observation) -> float: # type: ignore
return obs.get_user_account_data_base(self.original_address).healthFactor # type: ignore[arg-type]


def main() -> None:
def main(
*,
dashboard_server_port: Optional[int],
simulation_status_bar: bool,
auto_close: bool,
run_length: timedelta = timedelta(hours=6),
) -> None:
start_time = dateparser.parse("2023-03-11 00:00:00 UTC")
end_time = dateparser.parse("2023-03-11 06:00:00 UTC")
end_time = start_time + run_length
# Agents
agent1 = ConstantRewardAgent(
initial_portfolio={
Expand All @@ -60,12 +67,14 @@ def main() -> None:
backtest_run(
env=env,
policies=[policy],
dashboard_server_port=8051,
output_dir="./",
auto_close=True,
simulation_status_bar=True,
dashboard_server_port=dashboard_server_port,
output_file="aavev3.db",
auto_close=auto_close,
simulation_status_bar=simulation_status_bar,
simulation_title="AAVE strategy",
simulation_description="This example is maintaining a position between 2 health factor thresholds.",
)


if __name__ == "__main__":
main()
cli_runner.run_main(main)
42 changes: 24 additions & 18 deletions examples/active_lp/run.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,33 @@
import logging
import os
import sys
from datetime import timedelta
from decimal import Decimal
from typing import Optional

from dateutil import parser as dateparser

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

import cli_runner
from agents.uniswapV3_pool_wealth import UniswapV3PoolWealthAgent
from dateutil import parser as dateparser
from examples.active_lp.policy import ActiveConcentratedLP

from dojo.common.constants import Chain
from dojo.environments import UniswapV3Env
from dojo.runners import backtest_run

# Example logging configuration which:
# - by default, it logs only INFO
# - logs DEBUG messages only from "dojo.network"
# For more config options, see: https://docs.python.org/3.12/howto/logging-cookbook.html#customizing-handlers-with-dictconfig
logging.getLogger("dojo.network").setLevel(logging.DEBUG)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s",
level=logging.INFO,
handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)


def main() -> None:
def main(
*,
dashboard_server_port: Optional[int],
simulation_status_bar: bool,
auto_close: bool,
run_length: timedelta = timedelta(days=1),
) -> None:
# SNIPPET 1 START
pools = ["USDC/WETH-0.05"]
start_time = dateparser.parse("2023-05-01 00:00:00 UTC")
end_time = dateparser.parse("2023-05-02 00:00:00 UTC")
end_time = start_time + run_length

agent2 = UniswapV3PoolWealthAgent(
initial_portfolio={
Expand All @@ -53,9 +50,18 @@ def main() -> None:

active_lp_policy = ActiveConcentratedLP(agent=agent2, lp_width=2)

backtest_run(env, [active_lp_policy], dashboard_server_port=8051, auto_close=True)
backtest_run(
env,
[active_lp_policy],
dashboard_server_port=dashboard_server_port,
output_file="active_lp.db",
auto_close=auto_close,
simulation_status_bar=simulation_status_bar,
simulation_title="Active liquidity provisioning",
simulation_description="Keep liquidity in the active tick range.",
)
# SNIPPET 1 END


if __name__ == "__main__":
main()
cli_runner.run_main(main)
Loading

0 comments on commit edaba21

Please sign in to comment.