fix devices during backprop
Vectorrent committed Sep 6, 2024
1 parent 213bff9 · commit 5fc3f0e
Showing 1 changed file with 5 additions and 3 deletions.
hivemind/moe/client/moe.py (5 additions & 3 deletions)
@@ -309,16 +309,18 @@ def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor]
 
         num_samples, max_experts = dummy_grad_mask.shape
 
-        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
+        alive_ii_cpu = alive_ii.cpu()
+        alive_jj_cpu = alive_jj.cpu()
+        inputs_per_expert = zip(*(tensor[alive_ii_cpu].split(1, dim=0) for tensor in flat_inputs_cpu))
         grad_outputs_per_expert = zip(
-            *(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu)
+            *(tensor[alive_ii_cpu, alive_jj_cpu].split(1, dim=0) for tensor in flat_grad_outputs_cpu)
         )
         backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
 
         # dispatch tasks to all remote experts, collect responses
         pending_tasks = {}
         for i, j, inputs_ij, grad_outputs_ij in zip(
-            alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
+            alive_ii_cpu.numpy(), alive_jj_cpu.numpy(), inputs_per_expert, grad_outputs_per_expert
         ):
             expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
             stub = get_server_stub(expert.p2p, expert.peer_id)
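The change moves the alive_ii / alive_jj index tensors to CPU once and reuses the copies, so the CPU tensors in flat_inputs_cpu and flat_grad_outputs_cpu are never indexed with device-resident (e.g. CUDA) indices, and the dispatch loop no longer repeats the .cpu() transfer before calling .numpy(). A minimal standalone sketch of the mismatch this avoids (not part of the commit; flat_inputs_cpu and alive_ii here are toy stand-ins):

# Sketch of the device mismatch fixed above, under the assumption that the index
# tensors live on a CUDA device while the data tensors were already gathered on CPU.
import torch

flat_inputs_cpu = [torch.randn(4, 8)]       # CPU tensors, as in the backward pass
alive_ii = torch.tensor([0, 2])
if torch.cuda.is_available():
    alive_ii = alive_ii.cuda()              # indices may live on the GPU
    # flat_inputs_cpu[0][alive_ii] would raise a RuntimeError along the lines of
    # "indices should be either on cpu or on the same device as the indexed tensor"

alive_ii_cpu = alive_ii.cpu()               # compute the CPU copy once ...
per_sample = [t[alive_ii_cpu].split(1, dim=0) for t in flat_inputs_cpu]  # ... and reuse it
print(len(per_sample[0]))                   # 2 per-sample slices

Reusing alive_ii_cpu / alive_jj_cpu also means the loop over per-sample gradients can call .numpy() directly, without triggering a second device-to-host copy per index tensor.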
