diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index 839a7a2a9f0..028e431e137 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -16,9 +16,12 @@ # Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`} import os +import time from abc import ABCMeta, abstractmethod, abstractproperty from typing import List, Optional, Union +import yaml + from .logging import get_logger from .utils import LoggerType, is_comet_ml_available, is_tensorboard_available, is_wandb_available @@ -142,7 +145,8 @@ def tracker(self): def store_init_configuration(self, values: dict): """ - Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. + Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the + hyperparameters in a yaml file for future use. Args: values (Dictionary `str` to `bool`, `str`, `float` or `int`): @@ -151,7 +155,16 @@ def store_init_configuration(self, values: dict): """ self.writer.add_hparams(values, metric_dict={}) self.writer.flush() - logger.info("Stored initial configuration hyperparameters to TensorBoard") + project_run_name = time.time() + dir_name = os.path.join(self.logging_dir, str(project_run_name)) + os.makedirs(dir_name, exist_ok=True) + with open(os.path.join(dir_name, "hparams.yml"), "w") as outfile: + try: + yaml.dump(values, outfile) + except yaml.representer.RepresenterError: + logger.error("Serialization to store hyperparameters failed") + raise + logger.info("Stored initial configuration hyperparameters to TensorBoard and hparams yaml file") def log(self, values: dict, step: Optional[int] = None, **kwargs): """