Skip to content

Commit

Permalink
Enable distributed LoRA training (#821)
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath authored Nov 3, 2024
1 parent 29c954f commit 331148d
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 46 deletions.
81 changes: 55 additions & 26 deletions llms/mlx_lm/tuner/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten


Expand Down Expand Up @@ -84,9 +85,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
f" examples but only has {len(dataset)}."
)

# If running in distributed mode (N machines) then each one should skip N-1
# samples
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")

# Make the batches:
batch_idx = [
idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]

while True:
Expand All @@ -112,9 +120,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)

batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)

for j in range(batch_size):
for j in range(batch_size // step):
truncated_length = min(lengths[j], max_seq_length)
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
lengths[j] = (
Expand All @@ -138,7 +146,7 @@ def evaluate(
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
):
all_losses = []
all_losses = 0
ntokens = 0

index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
Expand All @@ -153,10 +161,14 @@ def evaluate(
),
):
losses, toks = loss(model, *batch)
all_losses.append((losses * toks).item())
ntokens += toks.item()
all_losses += losses * toks
ntokens += toks
mx.eval(all_losses, ntokens)

all_losses = mx.distributed.all_sum(all_losses)
ntokens = mx.distributed.all_sum(ntokens)

return np.sum(all_losses) / ntokens
return (all_losses / ntokens).item()


class TrainingCallback:
Expand All @@ -182,6 +194,11 @@ def train(
training_callback: TrainingCallback = None,
):
print(f"Starting training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")

if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
Expand All @@ -192,15 +209,19 @@ def step(batch):
# Forward and backward pass
(lvalue, toks), grad = loss_value_and_grad(model, *batch)

# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)

# Model update
optimizer.update(model, grad)

return lvalue, toks

loss_value_and_grad = nn.value_and_grad(model, loss)

losses = []
losses = 0
n_tokens = 0
steps = 0
trained_tokens = 0
# Main training loop
start = time.perf_counter()
Expand Down Expand Up @@ -229,9 +250,13 @@ def step(batch):
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
print(
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
)
if rank == 0:
print(
f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
)

if training_callback is not None:
val_info = {
Expand All @@ -244,30 +269,33 @@ def step(batch):
start = time.perf_counter()

lvalue, toks = step(batch)
mx.eval(state, lvalue, toks)

# Record loss
losses.append(lvalue.item())
n_tokens += toks.item()
losses += lvalue
n_tokens += toks
steps += 1
mx.eval(state, losses, n_tokens)

# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()

train_loss = np.mean(losses)
train_loss = mx.distributed.all_sum(losses).item()
train_loss /= steps * mx.distributed.init().size()
n_tokens = mx.distributed.all_sum(n_tokens).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 2**30
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB"
)
if rank == 0:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)

if training_callback is not None:
train_info = {
Expand All @@ -281,8 +309,9 @@ def step(batch):
}
training_callback.on_train_loss_report(train_info)

losses = []
losses = 0
n_tokens = 0
steps = 0
start = time.perf_counter()

# Save adapter weights
Expand Down
51 changes: 31 additions & 20 deletions llms/tests/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import sys
import unittest
from contextlib import contextmanager
from io import StringIO
from unittest.mock import MagicMock

Expand All @@ -17,6 +18,14 @@
from mlx_lm.tuner.utils import build_schedule


@contextmanager
def swapped_with_identity(obj, func):
old_func = getattr(obj, func)
setattr(obj, func, lambda x: x)
yield
setattr(obj, func, old_func)


class TestLora(unittest.TestCase):
def setUp(self):
self.capturedOutput = StringIO()
Expand Down Expand Up @@ -374,16 +383,17 @@ def test_evaluate_calls(self):
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
]
evaluate(
model=mock_model,
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
num_batches=2,
max_seq_length=2048,
loss=mock_default_loss,
iterate_batches=mock_iterate_batches,
)
with swapped_with_identity(mx.distributed, "all_sum"):
evaluate(
model=mock_model,
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
num_batches=2,
max_seq_length=2048,
loss=mock_default_loss,
iterate_batches=mock_iterate_batches,
)

mock_iterate_batches.assert_called_once_with(
dataset=mock_dataset,
Expand Down Expand Up @@ -412,16 +422,17 @@ def test_evaluate_infinite_batches(self):
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
]

evaluate(
model=mock_model,
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
num_batches=-1,
max_seq_length=2048,
loss=mock_default_loss,
iterate_batches=mock_iterate_batches,
)
with swapped_with_identity(mx.distributed, "all_sum"):
evaluate(
model=mock_model,
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
num_batches=-1,
max_seq_length=2048,
loss=mock_default_loss,
iterate_batches=mock_iterate_batches,
)

mock_iterate_batches.assert_called_once_with(
dataset=mock_dataset,
Expand Down

0 comments on commit 331148d

Please sign in to comment.