Only call the callbacks for rank==0
angeloskath committed Nov 1, 2024
1 parent c5e09a1 commit 3cd80c4
Showing 1 changed file with 18 additions and 18 deletions.
llms/mlx_lm/tuner/trainer.py (18 additions, 18 deletions)
@@ -258,13 +258,13 @@ def step(batch):
                     flush=True,
                 )

-            if training_callback is not None:
-                val_info = {
-                    "iteration": it,
-                    "val_loss": val_loss,
-                    "val_time": val_time,
-                }
-                training_callback.on_val_loss_report(val_info)
+                if training_callback is not None:
+                    val_info = {
+                        "iteration": it,
+                        "val_loss": val_loss,
+                        "val_time": val_time,
+                    }
+                    training_callback.on_val_loss_report(val_info)

             start = time.perf_counter()

@@ -297,17 +297,17 @@ def step(batch):
                     flush=True,
                 )

-            if training_callback is not None:
-                train_info = {
-                    "iteration": it,
-                    "train_loss": train_loss,
-                    "learning_rate": learning_rate,
-                    "iterations_per_second": it_sec,
-                    "tokens_per_second": tokens_sec,
-                    "trained_tokens": trained_tokens,
-                    "peak_memory": peak_mem,
-                }
-                training_callback.on_train_loss_report(train_info)
+                if training_callback is not None:
+                    train_info = {
+                        "iteration": it,
+                        "train_loss": train_loss,
+                        "learning_rate": learning_rate,
+                        "iterations_per_second": it_sec,
+                        "tokens_per_second": tokens_sec,
+                        "trained_tokens": trained_tokens,
+                        "peak_memory": peak_mem,
+                    }
+                    training_callback.on_train_loss_report(train_info)

             losses = 0
             n_tokens = 0
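The change is a pure re-indentation: the callback calls move inside the existing "if rank == 0:" guard, so in distributed fine-tuning every process still computes the losses, but only the rank-0 process prints and fires the user callbacks. Below is a minimal sketch of that pattern, not the trainer's exact code; report_val_loss and its arguments are illustrative, and it assumes the rank is obtained via mx.distributed.init(), with an on_val_loss_report hook matching the callback shown in the diff.

import mlx.core as mx


def report_val_loss(it, val_loss, val_time, training_callback=None):
    # Hypothetical helper showing the rank-0 guard: every process reaches
    # this point, but only rank 0 prints and invokes the callback, so each
    # report is emitted exactly once instead of once per process.
    rank = mx.distributed.init().rank()
    if rank == 0:
        print(
            f"Iter {it}: Val loss {val_loss:.3f}, Val took {val_time:.3f}s",
            flush=True,
        )
        if training_callback is not None:
            training_callback.on_val_loss_report(
                {"iteration": it, "val_loss": val_loss, "val_time": val_time}
            )

Without the guard, a run with N processes would print each log line and call each callback N times, duplicating metric reports.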
