Skip to content

Commit

Permalink
Revert "make run include the usage text dynamically"
Browse files Browse the repository at this point in the history
This reverts commit e4cd2a6.
  • Loading branch information
bpkroth committed Oct 9, 2024
1 parent e4cd2a6 commit 00d074d
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 164 deletions.
291 changes: 134 additions & 157 deletions mlos_bench/mlos_bench/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ class Launcher:
# pylint: disable=too-few-public-methods,too-many-instance-attributes
"""Command line launcher for mlos_bench and mlos_core."""

@staticmethod
def _get_parser(description: str, long_text: str) -> Tuple[argparse.ArgumentParser, List[str]]:
def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None):
# pylint: disable=too-many-statements
# pylint: disable=too-many-locals
_LOG.info("Launch: %s", description)
epilog = """
Additional --key=value pairs can be specified to augment or override
values listed in --globals.
Expand All @@ -54,8 +56,135 @@ def _get_parser(description: str, long_text: str) -> Tuple[argparse.ArgumentPars
the source tree:
<https://github.com/microsoft/MLOS/tree/main/mlos_bench/>
"""

parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog)
(args, path_args, args_rest) = self._parse_args(parser, argv)

# Bootstrap config loader: command line takes priority.
config_path = args.config_path or []
self._config_loader = ConfigPersistenceService({"config_path": config_path})
if args.config:
config = self._config_loader.load_config(args.config, ConfigSchema.CLI)
assert isinstance(config, Dict)
# Merge the args paths for the config loader with the paths from JSON file.
config_path += config.get("config_path", [])
self._config_loader = ConfigPersistenceService({"config_path": config_path})
else:
config = {}

log_level = args.log_level or config.get("log_level", _LOG_LEVEL)
try:
log_level = int(log_level)
except ValueError:
# failed to parse as an int - leave it as a string and let logging
# module handle whether it's an appropriate log name or not
log_level = logging.getLevelName(log_level)
logging.root.setLevel(log_level)
log_file = args.log_file or config.get("log_file")
if log_file:
log_handler = logging.FileHandler(log_file)
log_handler.setLevel(log_level)
log_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
logging.root.addHandler(log_handler)

self._parent_service: Service = LocalExecService(parent=self._config_loader)

# Prepare global_config from a combination of global config files, cli
# configs, and cli args.
args_dict = vars(args)
# teardown (bool) conflicts with Environment configs that use it for shell
# commands (list), so we exclude it from copying over
excluded_cli_args = path_args + ["teardown"]
# Include (almost) any item from the cli config file that either isn't in
# the cli args at all or whose cli arg is missing.
cli_config_args = {
key: val
for (key, val) in config.items()
if (args_dict.get(key) is None) and key not in excluded_cli_args
}

self.global_config = self._load_config(
args_globals=config.get("globals", []) + (args.globals or []),
config_path=(args.config_path or []) + config.get("config_path", []),
args_rest=args_rest,
global_config=cli_config_args,
)
# experiment_id is generally taken from --globals files, but we also allow
# overriding it on the CLI.
# It's useful to keep it there explicitly mostly for the --help output.
if args.experiment_id:
self.global_config["experiment_id"] = args.experiment_id
# trial_config_repeat_count is a scheduler property but it's convenient to
# set it via command line
if args.trial_config_repeat_count:
self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
# Ensure that the trial_id is present since it gets used by some other
# configs but is typically controlled by the run optimize loop.
self.global_config.setdefault("trial_id", 1)

self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
assert isinstance(self.global_config, dict)

# --service cli args should override the config file values.
service_files: List[str] = config.get("services", []) + (args.service or [])
assert isinstance(self._parent_service, SupportsConfigLoading)
self._parent_service = self._parent_service.load_services(
service_files,
self.global_config,
self._parent_service,
)

env_path = args.environment or config.get("environment")
if not env_path:
_LOG.error("No environment config specified.")
parser.error(
"At least the Environment config must be specified."
+ " Run `mlos_bench --help` and consult `README.md` for more info."
)
self.root_env_config = self._config_loader.resolve_path(env_path)

self.environment: Environment = self._config_loader.load_environment(
self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service
)
_LOG.info("Init environment: %s", self.environment)

# NOTE: Init tunable values *after* the Environment, but *before* the Optimizer
self.tunables = self._init_tunable_values(
args.random_init or config.get("random_init", False),
config.get("random_seed") if args.random_seed is None else args.random_seed,
config.get("tunable_values", []) + (args.tunable_values or []),
)
_LOG.info("Init tunables: %s", self.tunables)

self.optimizer = self._load_optimizer(args.optimizer or config.get("optimizer"))
_LOG.info("Init optimizer: %s", self.optimizer)

self.storage = self._load_storage(args.storage or config.get("storage"))
_LOG.info("Init storage: %s", self.storage)

self.teardown: bool = (
bool(args.teardown)
if args.teardown is not None
else bool(config.get("teardown", True))
)
self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
_LOG.info("Init scheduler: %s", self.scheduler)

@property
def config_loader(self) -> ConfigPersistenceService:
"""Get the config loader service."""
return self._config_loader

@property
def service(self) -> Service:
"""Get the parent service."""
return self._parent_service

@staticmethod
def _parse_args(
parser: argparse.ArgumentParser,
argv: Optional[List[str]],
) -> Tuple[argparse.Namespace, List[str], List[str]]:
"""Parse the command line arguments."""

class PathArgsTracker:
"""Simple class to help track which arguments are paths."""
Expand Down Expand Up @@ -235,166 +364,14 @@ def add_argument(self, *args: Any, **kwargs: Any) -> None:
"incompatible" is not easily automatable across systems.
""",
)
return (parser, path_args_tracker.path_args)

@staticmethod
def get_help_text(description: str, long_text: str) -> str:
"""Gets the help text from the argument parser.
Parameters
----------
description : str
The short name of the script.
long_text : str
The long description of the script.
Returns
-------
str
The help text from the argument parser.
"""
(parser, _path_args) = Launcher._get_parser(description, long_text)
return parser.format_help()

def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None):
# pylint: disable=too-many-statements
# pylint: disable=too-many-locals
_LOG.info("Launch: %s", description)
(parser, path_args) = self._get_parser(description, long_text)
(args, args_rest) = self._parse_args(parser, argv)

# Bootstrap config loader: command line takes priority.
config_path = args.config_path or []
self._config_loader = ConfigPersistenceService({"config_path": config_path})
if args.config:
config = self._config_loader.load_config(args.config, ConfigSchema.CLI)
assert isinstance(config, Dict)
# Merge the args paths for the config loader with the paths from JSON file.
config_path += config.get("config_path", [])
self._config_loader = ConfigPersistenceService({"config_path": config_path})
else:
config = {}

log_level = args.log_level or config.get("log_level", _LOG_LEVEL)
try:
log_level = int(log_level)
except ValueError:
# failed to parse as an int - leave it as a string and let logging
# module handle whether it's an appropriate log name or not
log_level = logging.getLevelName(log_level)
logging.root.setLevel(log_level)
log_file = args.log_file or config.get("log_file")
if log_file:
log_handler = logging.FileHandler(log_file)
log_handler.setLevel(log_level)
log_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
logging.root.addHandler(log_handler)

self._parent_service: Service = LocalExecService(parent=self._config_loader)

# Prepare global_config from a combination of global config files, cli
# configs, and cli args.
args_dict = vars(args)
# teardown (bool) conflicts with Environment configs that use it for shell
# commands (list), so we exclude it from copying over
excluded_cli_args = path_args + ["teardown"]
# Include (almost) any item from the cli config file that either isn't in
# the cli args at all or whose cli arg is missing.
cli_config_args = {
key: val
for (key, val) in config.items()
if (args_dict.get(key) is None) and key not in excluded_cli_args
}

self.global_config = self._load_config(
args_globals=config.get("globals", []) + (args.globals or []),
config_path=(args.config_path or []) + config.get("config_path", []),
args_rest=args_rest,
global_config=cli_config_args,
)
# experiment_id is generally taken from --globals files, but we also allow
# overriding it on the CLI.
# It's useful to keep it there explicitly mostly for the --help output.
if args.experiment_id:
self.global_config["experiment_id"] = args.experiment_id
# trial_config_repeat_count is a scheduler property but it's convenient to
# set it via command line
if args.trial_config_repeat_count:
self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count
# Ensure that the trial_id is present since it gets used by some other
# configs but is typically controlled by the run optimize loop.
self.global_config.setdefault("trial_id", 1)

self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True)
assert isinstance(self.global_config, dict)

# --service cli args should override the config file values.
service_files: List[str] = config.get("services", []) + (args.service or [])
assert isinstance(self._parent_service, SupportsConfigLoading)
self._parent_service = self._parent_service.load_services(
service_files,
self.global_config,
self._parent_service,
)

env_path = args.environment or config.get("environment")
if not env_path:
_LOG.error("No environment config specified.")
parser.error(
"At least the Environment config must be specified."
+ " Run `mlos_bench --help` and consult `README.md` for more info."
)
self.root_env_config = self._config_loader.resolve_path(env_path)

self.environment: Environment = self._config_loader.load_environment(
self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service
)
_LOG.info("Init environment: %s", self.environment)

# NOTE: Init tunable values *after* the Environment, but *before* the Optimizer
self.tunables = self._init_tunable_values(
args.random_init or config.get("random_init", False),
config.get("random_seed") if args.random_seed is None else args.random_seed,
config.get("tunable_values", []) + (args.tunable_values or []),
)
_LOG.info("Init tunables: %s", self.tunables)

self.optimizer = self._load_optimizer(args.optimizer or config.get("optimizer"))
_LOG.info("Init optimizer: %s", self.optimizer)

self.storage = self._load_storage(args.storage or config.get("storage"))
_LOG.info("Init storage: %s", self.storage)

self.teardown: bool = (
bool(args.teardown)
if args.teardown is not None
else bool(config.get("teardown", True))
)
self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler"))
_LOG.info("Init scheduler: %s", self.scheduler)

@property
def config_loader(self) -> ConfigPersistenceService:
"""Get the config loader service."""
return self._config_loader

@property
def service(self) -> Service:
"""Get the parent service."""
return self._parent_service

@staticmethod
def _parse_args(
parser: argparse.ArgumentParser,
argv: Optional[List[str]],
) -> Tuple[argparse.Namespace, List[str]]:
"""Parse the command line arguments."""
# By default we use the command line arguments, but allow the caller to
# provide some explicitly for testing purposes.
if argv is None:
argv = sys.argv[1:].copy()
(args, args_rest) = parser.parse_known_args(argv)
return (args, args_rest)

return (args, path_args_tracker.path_args, args_rest)

@staticmethod
def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]:
Expand Down
8 changes: 1 addition & 7 deletions mlos_bench/mlos_bench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@
_LOG = logging.getLogger(__name__)


_NAME = "mlos_bench"
_DESC = "Systems autotuning and benchmarking tool"
# Dynamically add the --help text to our docstring.
__doc__ += "\n" + Launcher.get_help_text(_NAME, _DESC)


def _sanity_check_results(launcher: Launcher) -> None:
"""Do some sanity checking on the results and throw an exception if it looks like
something went wrong.
Expand Down Expand Up @@ -63,7 +57,7 @@ def _sanity_check_results(launcher: Launcher) -> None:
def _main(
argv: Optional[List[str]] = None,
) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
launcher = Launcher(_NAME, _DESC, argv=argv)
launcher = Launcher("mlos_bench", "Systems autotuning and benchmarking tool", argv=argv)

with launcher.scheduler as scheduler_context:
scheduler_context.start()
Expand Down

0 comments on commit 00d074d

Please sign in to comment.