support the backward pass
xrsrke committed Jan 30, 2025
1 parent d7bf8be commit 3803b19
Showing 6 changed files with 62 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/nanotron/optim/gradient_accumulator.py
@@ -202,7 +202,7 @@ def build_grad_buffers(
         return fp32_grad_buffers, contiguous_buffer_f32_gradients
 
     def backward(self, loss: torch.Tensor):
-        if isinstance(loss, tuple):
+        if not isinstance(loss, torch.Tensor):
             assert 1 == 1
             raise NotImplementedError("Not implemented yet")
 
26 changes: 26 additions & 0 deletions src/nanotron/parallel/comm.py
@@ -0,0 +1,26 @@
+from typing import Dict
+
+
+class AsyncCommBucket:
+    """
+    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
+    RuntimeError: expected Variable or None (got tuple)
+    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
+    RuntimeError: expected Variable or None (got tuple)
+    """
+
+    _async_op: Dict[int, "dist.Work"] = {}
+
+    @staticmethod
+    def add(tensor_id: int, work: "dist.Work"):
+        AsyncCommBucket._async_op[tensor_id] = work
+
+    @staticmethod
+    def get(tensor_id: int):
+        return AsyncCommBucket._async_op.get(tensor_id)
+
+    @staticmethod
+    def wait(tensor_id: int):
+        work = AsyncCommBucket._async_op.pop(tensor_id)
+        work.wait()
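
For orientation, a minimal usage sketch of the new bucket (illustrative only, not part of the commit; it assumes torch.distributed has been initialized and uses a throwaway tensor x): the Work handle returned by an async all-reduce is stored under the tensor's id() and waited on later.

# Illustrative sketch, not from the repository. Assumes dist.init_process_group()
# has already been called and that `x` lives on a device the backend supports.
import torch
import torch.distributed as dist

from nanotron.parallel.comm import AsyncCommBucket

x = torch.ones(4)
handle = dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)  # returns a dist.Work
AsyncCommBucket.add(id(x), handle)   # register the pending work under the tensor's id

# ... overlap other computation here ...

AsyncCommBucket.wait(id(x))          # pop the handle and block until the all-reduce completes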
8 changes: 6 additions & 2 deletions src/nanotron/parallel/pipeline_parallel/engine.py
@@ -2,6 +2,9 @@
 from typing import Dict, Iterable, Optional, Union
 
 import torch
+from torch import nn as torch_nn
+from torch.nn.parallel import DistributedDataParallel
+
 from nanotron import distributed as dist
 from nanotron import logging
 from nanotron.distributed import ProcessGroup
@@ -12,8 +15,6 @@
 from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
 from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
 from nanotron.utils import ContextManagers
-from torch import nn as torch_nn
-from torch.nn.parallel import DistributedDataParallel
 
 logger = logging.get_logger(__name__)
 
@@ -83,6 +84,9 @@ def backward(
         if grad_accumulator is None:
             sum(activations).backward()
         else:
+            # if not isinstance(activations, torch.Tensor):
+            #     raise NotImplementedError("Only support sum of tensors for now")
+
             grad_accumulator.backward(sum(activations))
 
         # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang
@@ -19,6 +19,7 @@
 
 from nanotron import distributed as dist
 from nanotron.distributed import ProcessGroup
+from nanotron.parallel.comm import AsyncCommBucket
 
 
 class DifferentiableIdentity(torch.autograd.Function):
@@ -42,14 +43,29 @@ class DifferentiableAllReduceSum(torch.autograd.Function):
     def forward(
         ctx, tensor, group: Optional[ProcessGroup], async_all_reduce: bool
     ) -> Tuple[torch.Tensor, Optional["dist.Work"]]:
+        # ctx.mark_non_differentiable(async_all_reduce)
+        ctx.async_all_reduce = async_all_reduce
+
         if group.size() == 1:
             return tensor
 
+        orig_id = id(tensor)
         handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_all_reduce)
+        # if async_all_reduce:
+        #     handle.wait()
+        new_id = id(tensor)
+        assert 1 == 1
+        assert orig_id == new_id
+        # if async_all_reduce:
+        #     return tensor, handle
+        # else:
+        #     return tensor, None
         if async_all_reduce:
-            return tensor, handle
-        else:
-            return tensor, None
+            # AsyncCommBucket.add(tensor, handle)
+            # AsyncCommBucket.add(id(tensor), handle)
+            AsyncCommBucket.add(orig_id, handle)
+
+        return tensor
 
     @staticmethod
     def backward(ctx, grad_output):
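
This hunk is the heart of the commit: instead of returning a (tensor, handle) tuple from forward, which appears to be what surfaces downstream as the RuntimeError: expected Variable or None (got tuple) quoted in comm.py, forward now returns the tensor alone and parks the dist.Work handle in AsyncCommBucket, keyed by the tensor's id (stable across the call because dist.all_reduce mutates the tensor in place, which is what assert orig_id == new_id checks). A stripped-down sketch of the same pattern, not the repository's code, with an identity backward assumed for brevity:

# Stripped-down sketch (assumption, not the commit's exact code): a single-tensor output
# keeps autograd happy, while the async Work handle travels out of band via the bucket.
import torch
import torch.distributed as dist

from nanotron.parallel.comm import AsyncCommBucket


class AllReduceSumSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=True)
        AsyncCommBucket.add(id(tensor), handle)  # all_reduce is in-place, so id(tensor) is a stable key
        return tensor                            # a single tensor, never a (tensor, handle) tuple

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None                 # identity gradient assumed; None for the group argument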
10 changes: 9 additions & 1 deletion src/nanotron/parallel/tensor_parallel/functional.py
@@ -597,7 +597,15 @@ def row_linear(
     out = F.linear(input, weight, bias)
 
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-        out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        # out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        orig_out_id = id(out)
+        # NOTE: why the id(out) doesn't match the id(out) before the all_reduce?
+        out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
+        if async_all_reduce:
+            from nanotron.parallel.comm import AsyncCommBucket
+
+            work = AsyncCommBucket.get(orig_out_id)
+            assert 1 == 1
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         assert async_all_reduce is False, "Async communication is not supported for REDUCE_SCATTER mode."
         out = differentiable_reduce_scatter_sum(out, group=group)
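
On the NOTE in this hunk: a plausible explanation is that apply() on a custom autograd.Function hands back a different Python object than the tensor it was given, even when forward returns its input unchanged, so id(out) taken after the call no longer matches; capturing orig_out_id before the call sidesteps that. A toy check of that assumption about autograd behaviour (not something this commit asserts):

# Toy check, not from the repository: the output of Function.apply is expected to be a
# different Python object from its input, even for an identity forward, so id() keys
# must be captured before the call.
import torch


class IdentitySketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


x = torch.ones(3, requires_grad=True)
y = IdentitySketch.apply(x)
print(y is x)                         # expected: False, autograd re-wraps the returned input
print(y.data_ptr() == x.data_ptr())   # True, same underlying storage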
2 changes: 1 addition & 1 deletion src/nanotron/parallel/tensor_parallel/nn.py
@@ -293,7 +293,7 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         out = out * (~input_mask[..., None])
 
         if self.mode is TensorParallelLinearMode.ALL_REDUCE:
-            out, _ = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False)
+            out = differentiable_all_reduce_sum(out, group=self.pg, async_all_reduce=False)
         elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
             out = differentiable_reduce_scatter_sum(out, group=self.pg)
         else:
