diff --git a/farm/experiment.py b/farm/experiment.py index ab16ddc4a..4d8a1cb18 100644 --- a/farm/experiment.py +++ b/farm/experiment.py @@ -42,6 +42,19 @@ def load_experiments(file): def run_experiment(args): + + logger.info( + "\n***********************************************" + f"\n************* Experiment: {args.task.name} ************" + "\n************************************************" + ) + ml_logger = MlLogger(tracking_uri=args.logging.mlflow_url) + ml_logger.init_experiment( + experiment_name=args.logging.mlflow_experiment, + run_name=args.logging.mlflow_run_name, + nested=args.logging.mlflow_nested, + ) + validate_args(args) distributed = bool(args.general.local_rank != -1) diff --git a/run_all_experiments.py b/run_all_experiments.py index f8a092378..d895a1568 100644 --- a/run_all_experiments.py +++ b/run_all_experiments.py @@ -13,11 +13,7 @@ # limitations under the License. """Downstream runner for all experiments in specified config files.""" -import logging from farm.experiment import run_experiment, load_experiments -from farm.utils import MLFlowLogger - -logger = logging.getLogger(__name__) def main(): @@ -32,20 +28,8 @@ def main(): for conf_file in config_files: experiments = load_experiments(conf_file) - for args in experiments: - logger.info( - "\n***********************************************" - f"\n************* Experiment: {args.task.name} ************" - "\n************************************************" - ) - ml_logger = MLFlowLogger(tracking_uri=args.logging.mlflow_url) - ml_logger.init_experiment( - experiment_name=args.logging.mlflow_experiment, - run_name=args.logging.mlflow_run_name, - nested=args.logging.mlflow_nested, - ) - run_experiment(args) - + for experiment in experiments: + run_experiment(experiment) if __name__ == "__main__": main()