Commit

Use functor to execute sacred experiments as optuna trials.

ernestum committed Jan 12, 2024
1 parent 6ecfc34 commit 8828a35
Showing 1 changed file with 124 additions and 32 deletions.
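
For orientation: the "functor" named in the commit message is a dataclass whose __call__ method lets an instance be passed to Optuna as an objective, with the experiment to run and the per-trial hyperparameter sampling supplied as fields. Below is a minimal, self-contained sketch of that pattern, assuming nothing from the imitation codebase; the names RunAsTrial, run_fn and the toy quadratic objective are illustrative only.

import dataclasses
from typing import Callable, Dict

import optuna


@dataclasses.dataclass
class RunAsTrial:
    """Toy functor: sample a config for the trial, run it, return a score."""

    run_fn: Callable[[Dict], float]                 # the experiment to execute
    suggest_config: Callable[[optuna.Trial], Dict]  # per-trial hyperparameter sampling

    def __call__(self, trial: optuna.Trial) -> float:
        config = self.suggest_config(trial)
        trial.set_user_attr("config", config)       # record what was actually run
        return self.run_fn(config)


# Optuna treats the instance like any other objective callable.
objective = RunAsTrial(
    run_fn=lambda cfg: -(cfg["x"] - 2.0) ** 2,
    suggest_config=lambda t: {"x": t.suggest_float("x", -10.0, 10.0)},
)
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)

In the commit itself, the fields hold the sacred experiment and two suggestion callbacks, and __call__ additionally takes the sacred run options and extra named configs, as the diff below shows.
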
156 changes: 124 additions & 32 deletions tuning/tune.py
@@ -1,17 +1,69 @@
import argparse
import dataclasses
from typing import List, Mapping, Any, Callable, Dict

import optuna
import optuna.distributions as dist
import sacred

import imitation.scripts.train_preference_comparisons


-def suggest_pc_run_params(trial: optuna.Trial) -> dict:
-    return dict(
@dataclasses.dataclass
class RunSacredAsTrial:
"""Runs a sacred experiment as an optuna trial.
Assumes that the sacred experiment returns a dict with a key 'imit_stats' that
contains a dict with a key 'monitor_return_mean'.
"""

"""The sacred experiment to run."""
sacred_ex: sacred.Experiment

"""A function that returns a list of named configs to pass to sacred.run."""
suggest_named_configs: Callable[[optuna.Trial], List[str]]

"""A function that returns a dict of config updates to pass to sacred.run."""
suggest_config_updates: Callable[[optuna.Trial], Mapping[str, Any]]

def __call__(
self,
trial: optuna.Trial,
run_options: Dict,
extra_named_configs: List[str]
) -> float:
"""Run the sacred experiment and return the performance.
Args:
trial: The optuna trial to sample hyperparameters for.
run_options: Options to pass to sacred.run(options=).
extra_named_configs: Additional named configs to pass to sacred.run.
"""

config_updates = self.suggest_config_updates(trial)
named_configs = self.suggest_named_configs(trial) + extra_named_configs

trial.set_user_attr("config_updates", config_updates)
trial.set_user_attr("named_configs", named_configs)

experiment: sacred.Experiment = self.sacred_ex
result = experiment.run(
config_updates=config_updates,
named_configs=named_configs,
options=run_options,
)
if result.status != "COMPLETED":
raise RuntimeError(
f"Trial failed with {result.fail_trace()} and status {result.status}."
)
return result.result['imit_stats']['monitor_return_mean']


"""A mapping from algorithm names to functions that run the algorithm as an optuna trial."""
objectives_by_algo = dict(
    pc=RunSacredAsTrial(
        sacred_ex=imitation.scripts.train_preference_comparisons.train_preference_comparisons_ex,
-        named_configs=["reward.reward_ensemble"],
-        config_updates={
        suggest_named_configs=lambda _: ["reward.reward_ensemble"],
        suggest_config_updates=lambda trial: {
            "seed": trial.number,
            "environment": {"num_vec": 1},
            "total_timesteps": 2e7,
@@ -47,44 +99,84 @@ def suggest_pc_run_params(trial: optuna.Trial) -> dict:
                },
            },
        },
-    )
    ),
)


-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--algo", type=str, default="pc")
def make_parser() -> argparse.ArgumentParser:
    example_usage = "python -m imitation.scripts.tune pc seals_swimmer"
    possible_named_configs = "\n".join(
        f" - {algo}: {', '.join(objective.sacred_ex.named_configs.keys())}"
        for algo, objective in objectives_by_algo.items()
    )

-    args = parser.parse_args()
    parser = argparse.ArgumentParser(
        description="Tune hyperparameters for imitation learning algorithms.",
        epilog=f"Example usage:\n{example_usage}\n\nPossible named configs:\n{possible_named_configs}",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "algo",
        type=str,
        default="pc",
        choices=objectives_by_algo.keys(),
        help="What algorithm to tune.",
    )
    parser.add_argument(
        "named_configs",
        type=str,
        nargs="+",
        default=[],
        help="Additional named configs to pass to the sacred experiment. "
        "Use this to select the environment to tune on.",
    )
    parser.add_argument(
        "--num_trials",
        type=int,
        default=100,
        help="Number of trials to run."
    )
    parser.add_argument(
        "-j",
        "--journal-log",
        type=str,
        default=None,
        help="A journal file to synchronize multiple instances of this script. "
        "Works on NFS storage."
    )
    return parser

-    if args.algo != "pc":
-        raise NotImplementedError(f"Tuning algorithm '{args.algo}' not implemented.")

-    study: optuna.Study = optuna.create_study(
-        study_name=f"tuning_{args.algo}"
def make_study(args: argparse.Namespace) -> optuna.Study:
    if args.journal_log is not None:
        storage = optuna.storages.JournalStorage(
            optuna.storages.JournalFileStorage(args.journal_log)
        )
    else:
        storage = None

    return optuna.create_study(
        study_name=f"tuning_{args.algo}_with_{'_'.join(args.named_configs)}",
        storage=storage,
        load_if_exists=True,
        direction="maximize",
    )

-    def objective(trial: optuna.Trial) -> float:
-        run_params = suggest_pc_run_params(trial)
-        trial.set_user_attr("config_updates", run_params["config_updates"])
-        trial.set_user_attr("named_configs", run_params["named_configs"])
-        experiment: sacred.Experiment = run_params["sacred_ex"]
-        result = experiment.run(
-            config_updates=run_params["config_updates"],
-            named_configs=run_params["named_configs"],
-            options={"--name": study.study_name, "--file_storage": "sacred"},
-        )
-        if result.status != "COMPLETED":
-            raise RuntimeError(
-                f"Trial failed with {result.fail_trace()} and status {result.status}."
-            )
-        return result.result['imit_stats']['monitor_return_mean']

def main():
    parser = make_parser()
    args = parser.parse_args()
    study = make_study(args)

    study.optimize(
-        objective,
-        callbacks=[optuna.study.MaxTrialsCallback(100)]
        lambda trial: objectives_by_algo[args.algo](
            trial,
            run_options={"--name": study.study_name, "--file_storage": "sacred"},
            extra_named_configs=args.named_configs
        ),
        callbacks=[optuna.study.MaxTrialsCallback(args.num_trials)]
    )


if __name__ == '__main__':
    main()
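
Each trial stores the configuration it actually ran as user attributes (config_updates and named_configs), so the tuning results can be read back from the study once the script has finished. A hedged sketch of doing that, assuming the script was launched as "python tuning/tune.py pc seals_swimmer -j study.log"; the journal file name and the named config are illustrative, and the study name follows the pattern used in make_study.

import optuna

# Reopen the journal-backed storage that tune.py wrote to.
storage = optuna.storages.JournalStorage(
    optuna.storages.JournalFileStorage("study.log")  # assumed --journal-log value
)
study = optuna.load_study(
    study_name="tuning_pc_with_seals_swimmer",  # pattern from make_study()
    storage=storage,
)

best = study.best_trial
print("best mean return:", best.value)
print("config updates:", best.user_attrs["config_updates"])
print("named configs:", best.user_attrs["named_configs"])

Because the journal storage works on NFS, several copies of the script pointed at the same journal file can contribute trials to the same study in parallel.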
