From 3cd80c46b09671797552e842b3ac080f17101a70 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 1 Nov 2024 11:45:18 -0700 Subject: [PATCH] Only call the callbacks for rank==0 --- llms/mlx_lm/tuner/trainer.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 38619d956..1434d935f 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -258,13 +258,13 @@ def step(batch): flush=True, ) - if training_callback is not None: - val_info = { - "iteration": it, - "val_loss": val_loss, - "val_time": val_time, - } - training_callback.on_val_loss_report(val_info) + if training_callback is not None: + val_info = { + "iteration": it, + "val_loss": val_loss, + "val_time": val_time, + } + training_callback.on_val_loss_report(val_info) start = time.perf_counter() @@ -297,17 +297,17 @@ def step(batch): flush=True, ) - if training_callback is not None: - train_info = { - "iteration": it, - "train_loss": train_loss, - "learning_rate": learning_rate, - "iterations_per_second": it_sec, - "tokens_per_second": tokens_sec, - "trained_tokens": trained_tokens, - "peak_memory": peak_mem, - } - training_callback.on_train_loss_report(train_info) + if training_callback is not None: + train_info = { + "iteration": it, + "train_loss": train_loss, + "learning_rate": learning_rate, + "iterations_per_second": it_sec, + "tokens_per_second": tokens_sec, + "trained_tokens": trained_tokens, + "peak_memory": peak_mem, + } + training_callback.on_train_loss_report(train_info) losses = 0 n_tokens = 0