Commit e3de823

Permalink: https://github.com/ml-explore/mlx-examples/pull/821
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
yeahdongcn committed Aug 6, 2024
1 parent 0d208b3 commit e3de823
Showing 2 changed files with 35 additions and 8 deletions.
41 changes: 33 additions & 8 deletions llms/mlx_lm/tuner/trainer.py
@@ -8,7 +8,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_map
 
 
 def grad_checkpoint(layer):
@@ -27,6 +27,17 @@ def inner_fn(params, *args, **kwargs):
     type(layer).__call__ = checkpointed_fn
 
 
+def average_gradients(gradients):
+    world_size = mx.distributed.init().size()
+    if world_size == 1:
+        return gradients
+
+    def _all_average(x):
+        return mx.distributed.all_sum(x) / world_size
+
+    return tree_map(_all_average, gradients)
+
+
 @dataclass
 class TrainingArgs:
     batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
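
Note (not part of the diff): average_gradients walks the gradient tree with tree_map, all-reduces every array with mx.distributed.all_sum, and divides by the world size, so every worker applies the same averaged gradients. A minimal usage sketch, assuming this commit is installed; the gradient names are made up, and on a single process the call returns its input unchanged:

import mlx.core as mx
from mlx_lm.tuner.trainer import average_gradients

# Toy gradient tree standing in for the real adapter gradients.
local_grads = {"lora_a": mx.ones((4, 4)), "lora_b": mx.zeros((4,))}

# Each worker passes its local gradients and receives the cross-worker average;
# with a single process (world size 1) the tree is returned as-is.
avg = average_gradients(local_grads)
print(avg["lora_a"].sum())  # sums to 16 on one process; identical on every rank otherwise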
@@ -82,9 +93,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             f" examples but only has {len(dataset)}."
         )
 
+    # If running in distributed mode (N machines) then each one should skip N-1
+    # samples
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+
     # Make the batches:
     batch_idx = [
-        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
+        idx[i : i + batch_size : step]
+        for i in range(0, len(idx) - batch_size + 1, batch_size)
     ]
 
     while True:
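
Note (not part of the diff): the strided slice above keeps batch_size // step samples out of every window of batch_size sorted indices, which is why the batch size must be divisible by the number of workers. A small sketch of the index arithmetic with assumed values batch_size = 4 and world size (step) 2:

idx = list(range(10))  # stand-in for the length-sorted example indices
batch_size, step = 4, 2

batch_idx = [
    idx[i : i + batch_size : step]
    for i in range(0, len(idx) - batch_size + 1, batch_size)
]
print(batch_idx)  # [[0, 2], [4, 6]]: each local batch holds batch_size // step = 2 samples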
@@ -112,9 +130,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
             max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
-            batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
 
-            for j in range(batch_size):
+            for j in range(batch_size // step):
                 truncated_length = min(lengths[j], max_seq_length)
                 batch_arr[j, :truncated_length] = batch[j][:truncated_length]
                 lengths[j] = (
@@ -138,7 +156,7 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = []
+    all_losses = 0
     ntokens = 0
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -153,10 +171,14 @@
         ),
     ):
         losses, toks = loss(model, *batch)
-        all_losses.append((losses * toks).item())
-        ntokens += toks.item()
+        all_losses += losses * toks
+        ntokens += toks
+        mx.eval(all_losses, ntokens)
 
+    all_losses = mx.distributed.all_sum(all_losses)
+    ntokens = mx.distributed.all_sum(ntokens)
+
-    return np.sum(all_losses) / ntokens
+    return (all_losses / ntokens).item()
 
 
 class TrainingCallback:
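
Note (not part of the diff): evaluate now accumulates a token-weighted loss sum and a token count as MLX arrays (mx.eval forces the lazy graph each iteration), then all-sums both across workers so every rank returns the same global token-weighted mean. A minimal single-process sketch with made-up numbers; all_sum simply returns its input when no distributed group has been launched:

import mlx.core as mx

# Pretend one worker saw two batches: mean loss 2.0 over 100 tokens, then 4.0 over 50.
all_losses = mx.array(2.0) * 100 + mx.array(4.0) * 50
ntokens = mx.array(100 + 50)

# In a real multi-host run these produce global totals; here they return the inputs.
all_losses = mx.distributed.all_sum(all_losses)
ntokens = mx.distributed.all_sum(ntokens)

print((all_losses / ntokens).item())  # 2.666..., the token-weighted mean loss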
@@ -192,6 +214,9 @@ def step(batch):
         # Forward and backward pass
         (lvalue, toks), grad = loss_value_and_grad(model, *batch)
 
+        # All reduce the gradients if running in distributed mode
+        grad = average_gradients(grad)
+
         # Model update
         optimizer.update(model, grad)
 
2 changes: 2 additions & 0 deletions lora/hostfile
@@ -0,0 +1,2 @@
+10.0.0.2 slots=1
+10.0.0.19 slots=1
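
Note (not part of the diff): the two lines above look like an OpenMPI-style hostfile (one host per line with a slot count), the kind typically passed to mpirun when launching the LoRA trainer across both machines; the IPs are the example values from the commit. A small sketch, assuming MLX's MPI backend is available, of how a worker can inspect the distributed group that mx.distributed.init() returns under such a launch:

import mlx.core as mx

group = mx.distributed.init()      # falls back to a single-process group if not launched via MPI
print(group.rank(), group.size())  # e.g. "0 2" on one host and "1 2" on the other with two workers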
